From ad3059892391debd25bb3adcfed127523db16d90 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 4 Feb 2025 11:13:12 +0100
Subject: [PATCH] Update Mistral converter (#35967)

* Update convert_mistral_weights_to_hf.py

* Update convert_mistral_weights_to_hf.py

* update

* style

* move it to integrations

* style

* trigger CIs

* trigger CIs
---
 src/transformers/integrations/mistral.py     | 105 ++++++++++++++++++
 .../mistral/convert_mistral_weights_to_hf.py |  68 ++++++------
 .../pixtral/convert_pixtral_weights_to_hf.py |  73 ------------
 3 files changed, 142 insertions(+), 104 deletions(-)
 create mode 100644 src/transformers/integrations/mistral.py

diff --git a/src/transformers/integrations/mistral.py b/src/transformers/integrations/mistral.py
new file mode 100644
index 000000000..781723292
--- /dev/null
+++ b/src/transformers/integrations/mistral.py
@@ -0,0 +1,105 @@
+from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers.models import BPE
+
+from transformers import LlamaTokenizerFast
+from transformers.convert_slow_tokenizer import bytes_to_unicode
+
+
+class MistralConverter:
+    """
+    A general tiktoken converter.
+    """
+
+    def __init__(
+        self,
+        vocab=None,
+        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
+        add_prefix_space=False,
+        additional_special_tokens=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args)
+        self.vocab = vocab
+        self.pattern = pattern
+        self.add_prefix_space = add_prefix_space
+        self.additional_special_tokens = additional_special_tokens
+
+    def extract_vocab_merges_from_model(self, vocab: str):
+        bpe_ranks = vocab
+        byte_encoder = bytes_to_unicode()
+
+        def token_bytes_to_string(b):
+            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
+
+        merges = []
+        vocab = {}
+        for idx, (token, rank) in enumerate(bpe_ranks.items()):
+            if token not in self.additional_special_tokens:
+                vocab[token_bytes_to_string(token)] = idx
+                if len(token) == 1:
+                    continue
+                local = []
+                for index in range(1, len(token)):
+                    piece_l, piece_r = token[:index], token[index:]
+                    if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
+                        local.append((piece_l, piece_r, rank))
+                local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
+                merges.extend(local)
+            else:
+                vocab[token] = idx
+        merges = sorted(merges, key=lambda val: val[2], reverse=False)
+        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
+        return vocab, merges
+
+    def tokenizer(self):
+        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
+        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
+        if hasattr(tokenizer.model, "ignore_merges"):
+            tokenizer.model.ignore_merges = True
+        return tokenizer
+
+    def converted(self) -> Tokenizer:
+        tokenizer = self.tokenizer()
+        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
+                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
+            ]
+        )
+        tokenizer.decoder = decoders.ByteLevel()
+        tokenizer.add_special_tokens(self.additional_special_tokens)
+
+        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+        return tokenizer
+
+
+def convert_tekken_tokenizer(tokenizer_file: str):
+    """Convert a "tekken" tokenizer to a fast Tokenizer."""
+    # Tekken format -- need to use the Converter
+
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+    # Load directly using their lib
+    mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file)
+
+    # Extract vocab and special tokens
+    vocab = mistral_tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
+    all_special = [
+        token.value if hasattr(token, "value") else token
+        for token in mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
+    ]
+    specials_tokens = {token: all_special.index(token) for token in all_special}
+    specials_tokens.update(vocab)
+    vocab = specials_tokens
+
+    # Convert
+    tokenizer = LlamaTokenizerFast(
+        tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
+    )
+
+    # Post-process
+    tokenizer.add_special_tokens({"additional_special_tokens": all_special})
+
+    return tokenizer
diff --git a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py
index 1a89ade8f..1fc4ad90e 100644
--- a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py
+++ b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py
@@ -15,25 +15,14 @@
 import argparse
 import json
 import os
 import re
-import warnings
 
 import torch
 from safetensors.torch import load_file
 
-from transformers import LlamaTokenizer, MistralConfig, MistralForCausalLM
+from transformers import AutoTokenizer, LlamaTokenizerFast, MistralConfig, MistralForCausalLM
+from transformers.integrations.mistral import convert_tekken_tokenizer
 
-try:
-    from transformers import LlamaTokenizerFast
-
-    tokenizer_class = LlamaTokenizerFast
-except ImportError as e:
-    warnings.warn(e)
-    warnings.warn(
-        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
-    )
-    tokenizer_class = LlamaTokenizer
-
 # fmt: off
 STATE_DICT_MAPPING = {
     # CausalLM keys
@@ -87,23 +76,24 @@ def convert_state_dict(original_state_dict: dict, config: MistralConfig):
     """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
     new_dict = {}
 
-    n_heads = config.num_attention_heads
-    dim = config.hidden_size
-    dims_per_head = dim // n_heads
+    num_attention_heads = config.num_attention_heads
+    hidden_size = config.hidden_size
+    head_dim = config.head_dim
     num_key_value_heads = config.num_key_value_heads
-    key_value_dim = dims_per_head * num_key_value_heads
+    key_value_dim = head_dim * num_key_value_heads
+    query_dim = head_dim * num_attention_heads
 
     for old_key, tensor in original_state_dict.items():
         new_key = map_old_key_to_new(old_key)
 
         if "q_proj" in new_key:
-            tensor = tensor.view(n_heads, dims_per_head, dim).reshape(dim, dim)
-            tensor = permute_for_rope(tensor, n_heads, dim, dim)
+            tensor = tensor.view(num_attention_heads, head_dim, hidden_size).reshape(query_dim, hidden_size)
+            tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size)
         elif "k_proj" in new_key:
-            tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
-            tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim)
+            tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)
+            tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
         elif "v_proj" in new_key:
-            tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
+            tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)
 
         new_dict[new_key] = tensor
     return new_dict
@@ -169,7 +159,7 @@ def convert_state_dict_sharded(loaded_shards: list[dict], config: MistralConfig)
     return new_dict
 
 
-def convert_config(original_config: dict, max_position_embeddings: int):
+def convert_config(original_config: dict, max_position_embeddings: int = 32768):
     key_mapping = {
         "hidden_size": "dim",
         "num_hidden_layers": "n_layers",
@@ -191,9 +181,7 @@ def convert_config(original_config: dict, max_position_embeddings: int):
         "n_kv_heads", new_config_kwargs["num_attention_heads"]
     )
     new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0)
-
-    # This is never provided in `params.json`, we provide it manually
-    new_config_kwargs["max_position_embeddings"] = max_position_embeddings
+    new_config_kwargs["max_position_embeddings"] = original_config.get("max_seq_len", max_position_embeddings)
 
     # This may sometimes be a string in `params.json`
     if new_config_kwargs["sliding_window"] is not None:
@@ -230,11 +218,23 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd
     model.save_pretrained(output_dir)
 
 
-def convert_and_write_tokenizer(input_dir: str, output_dir: str):
+def convert_and_write_tokenizer(input_dir: str, output_dir: str, tokenizer_template_name: str = ""):
     """Convert the tokenizer and save it."""
-    # May have .v3 or .v7 at the end
-    tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
-    tokenizer = tokenizer_class(os.path.join(input_dir, tokenizer_file))
+    # Tekken format
+    if "tekken.json" in os.listdir(input_dir):
+        tokenizer_file = os.path.join(input_dir, "tekken.json")
+        tokenizer = convert_tekken_tokenizer(tokenizer_file)
+    else:
+        # May have .v3 or .v7 at the end
+        tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
+        tokenizer = LlamaTokenizerFast(os.path.join(input_dir, tokenizer_file))
+
+    # Load a chat template from another model
+    if tokenizer_template_name != "":
+        template_tok = AutoTokenizer.from_pretrained(tokenizer_template_name)
+        tokenizer.chat_template = template_tok.chat_template
+
+    # Finally save it
     tokenizer.save_pretrained(output_dir)
 
 
@@ -248,6 +248,12 @@ def main():
         "output_dir",
         help="Location to write HF model and tokenizer",
     )
+    parser.add_argument(
+        "--template_name",
+        type=str,
+        default="",
+        help="Another model name from which to copy the chat template.",
+    )
     parser.add_argument(
         "--max_position_embeddings",
         type=int,
@@ -269,7 +275,7 @@ def main():
     if not args.tokenizer_only:
         convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.modules_are_split)
 
-    convert_and_write_tokenizer(args.input_dir, args.output_dir)
+    convert_and_write_tokenizer(args.input_dir, args.output_dir, args.template_name)
 
 
 if __name__ == "__main__":
diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
index a8b3ae502..ee1f1e9eb 100644
--- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
+++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
@@ -19,8 +19,6 @@ import regex as re
 import torch
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from safetensors.torch import load_file as safe_load_file
-from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
-from tokenizers.models import BPE
 
 from transformers import (
     LlavaConfig,
@@ -30,7 +28,6 @@ from transformers import (
     PixtralProcessor,
     PixtralVisionConfig,
 )
-from transformers.convert_slow_tokenizer import bytes_to_unicode
 
 
 """
@@ -87,76 +84,6 @@ OLD_KEY_TO_NEW_KEY_MAPPING = {
 }
 
 
-class MistralConverter:
-    """
-    A general tiktoken converter.
-    """
-
-    def __init__(
-        self,
-        vocab=None,
-        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
-        add_prefix_space=False,
-        additional_special_tokens=None,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args)
-        self.vocab = vocab
-        self.pattern = pattern
-        self.add_prefix_space = add_prefix_space
-        self.additional_special_tokens = additional_special_tokens
-
-    def extract_vocab_merges_from_model(self, vocab: str):
-        bpe_ranks = vocab
-        byte_encoder = bytes_to_unicode()
-
-        def token_bytes_to_string(b):
-            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
-
-        merges = []
-        vocab = {}
-        for idx, (token, rank) in enumerate(bpe_ranks.items()):
-            if token not in self.additional_special_tokens:
-                vocab[token_bytes_to_string(token)] = idx
-                if len(token) == 1:
-                    continue
-                local = []
-                for index in range(1, len(token)):
-                    piece_l, piece_r = token[:index], token[index:]
-                    if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
-                        local.append((piece_l, piece_r, rank))
-                local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
-                merges.extend(local)
-            else:
-                vocab[token] = idx
-        merges = sorted(merges, key=lambda val: val[2], reverse=False)
-        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
-        return vocab, merges
-
-    def tokenizer(self):
-        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
-        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
-        if hasattr(tokenizer.model, "ignore_merges"):
-            tokenizer.model.ignore_merges = True
-        return tokenizer
-
-    def converted(self) -> Tokenizer:
-        tokenizer = self.tokenizer()
-        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
-                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
-            ]
-        )
-        tokenizer.decoder = decoders.ByteLevel()
-        tokenizer.add_special_tokens(self.additional_special_tokens)
-
-        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
-
-        return tokenizer
-
-
 def convert_mistral_tokenizer(model_file):
     from transformers import LlamaTokenizer
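Usage sketch (not part of the patch): a minimal example of how the converter might be driven once this patch is applied. The local paths and the template repo name below are hypothetical placeholders; `convert_tekken_tokenizer` and the `--template_name` flag are taken from the diff above.

    # Convert a standalone "tekken" tokenizer with the new integration helper
    # (assumes a downloaded Mistral release directory containing tekken.json)
    from transformers.integrations.mistral import convert_tekken_tokenizer

    tokenizer = convert_tekken_tokenizer("raw_weights/tekken.json")  # placeholder path
    tokenizer.save_pretrained("converted_model")  # placeholder output directory

    # Full conversion (weights + tokenizer) via the updated script, copying the
    # chat template from an existing checkpoint (placeholder names):
    #   python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
    #       raw_weights converted_model --template_name <hub-model-with-chat-template>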