Update Mistral converter (#35967)

* Update convert_mistral_weights_to_hf.py

* Update convert_mistral_weights_to_hf.py

* update

* style

* move it to integrations

* style

* trigger CIs

* trigger CIs
Cyril Vallez 2025-02-04 11:13:12 +01:00 committed by GitHub
parent b1954fd64a
commit ad30598923
3 changed files with 142 additions and 104 deletions

@@ -0,0 +1,105 @@
from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
from tokenizers.models import BPE

from transformers import LlamaTokenizerFast
from transformers.convert_slow_tokenizer import bytes_to_unicode


class MistralConverter:
    """
    A general tiktoken converter.
    """

    def __init__(
        self,
        vocab=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        additional_special_tokens=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args)
        self.vocab = vocab
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.additional_special_tokens = additional_special_tokens

    def extract_vocab_merges_from_model(self, vocab: str):
        bpe_ranks = vocab
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

        merges = []
        vocab = {}
        for idx, (token, rank) in enumerate(bpe_ranks.items()):
            if token not in self.additional_special_tokens:
                vocab[token_bytes_to_string(token)] = idx
                if len(token) == 1:
                    continue
                local = []
                for index in range(1, len(token)):
                    piece_l, piece_r = token[:index], token[index:]
                    if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
                        local.append((piece_l, piece_r, rank))
                local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
                merges.extend(local)
            else:
                vocab[token] = idx
        merges = sorted(merges, key=lambda val: val[2], reverse=False)
        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
        return vocab, merges

    def tokenizer(self):
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.add_special_tokens(self.additional_special_tokens)
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer


def convert_tekken_tokenizer(tokenizer_file: str):
    """Convert a "tekken" tokenizer to a fast Tokenizer."""
    # Tekken format -- need to use the Converter
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    # Load directly using their lib
    mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file)

    # Extract vocab and special tokens
    vocab = mistral_tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
    all_special = [
        token.value if hasattr(token, "value") else token
        for token in mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
    ]
    specials_tokens = {token: all_special.index(token) for token in all_special}
    specials_tokens.update(vocab)
    vocab = specials_tokens

    # Convert
    tokenizer = LlamaTokenizerFast(
        tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
    )

    # Post-process
    tokenizer.add_special_tokens({"additional_special_tokens": all_special})

    return tokenizer
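The converter above rebuilds byte-level BPE merges purely from a token-to-rank mapping. A minimal sketch of that behavior with an invented six-token vocabulary (the import path assumes the module lands in `transformers.integrations.mistral`, which is where the converter script below imports `convert_tekken_tokenizer` from):

from transformers.integrations.mistral import MistralConverter

# Toy byte-level vocab: insertion order equals rank here, so ids and ranks coincide.
bpe_ranks = {b"t": 0, b"h": 1, b"e": 2, b"th": 3, b"he": 4, b"the": 5}

converter = MistralConverter(vocab=bpe_ranks, additional_special_tokens=[])
vocab, merges = converter.extract_vocab_merges_from_model(bpe_ranks)

print(vocab)   # {'t': 0, 'h': 1, 'e': 2, 'th': 3, 'he': 4, 'the': 5}
print(merges)  # [('t', 'h'), ('h', 'e'), ('t', 'he'), ('th', 'e')]

# For a real checkpoint, convert_tekken_tokenizer("path/to/tekken.json") wraps this class
# and returns a LlamaTokenizerFast ready to be saved with save_pretrained().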

@@ -15,25 +15,14 @@ import argparse
 import json
 import os
 import re
-import warnings

 import torch
 from safetensors.torch import load_file

-from transformers import LlamaTokenizer, MistralConfig, MistralForCausalLM
+from transformers import AutoTokenizer, LlamaTokenizerFast, MistralConfig, MistralForCausalLM
+from transformers.integrations.mistral import convert_tekken_tokenizer

-try:
-    from transformers import LlamaTokenizerFast
-    tokenizer_class = LlamaTokenizerFast
-except ImportError as e:
-    warnings.warn(e)
-    warnings.warn(
-        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
-    )
-    tokenizer_class = LlamaTokenizer

 # fmt: off
 STATE_DICT_MAPPING = {
     # CausalLM keys
@@ -87,23 +76,24 @@ def convert_state_dict(original_state_dict: dict, config: MistralConfig):
     """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
     new_dict = {}

-    n_heads = config.num_attention_heads
-    dim = config.hidden_size
-    dims_per_head = dim // n_heads
+    num_attention_heads = config.num_attention_heads
+    hidden_size = config.hidden_size
+    head_dim = config.head_dim
     num_key_value_heads = config.num_key_value_heads
-    key_value_dim = dims_per_head * num_key_value_heads
+    key_value_dim = head_dim * num_key_value_heads
+    query_dim = head_dim * num_attention_heads

     for old_key, tensor in original_state_dict.items():
         new_key = map_old_key_to_new(old_key)

         if "q_proj" in new_key:
-            tensor = tensor.view(n_heads, dims_per_head, dim).reshape(dim, dim)
-            tensor = permute_for_rope(tensor, n_heads, dim, dim)
+            tensor = tensor.view(num_attention_heads, head_dim, hidden_size).reshape(query_dim, hidden_size)
+            tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size)
         elif "k_proj" in new_key:
-            tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
-            tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim)
+            tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)
+            tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
         elif "v_proj" in new_key:
-            tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
+            tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)

         new_dict[new_key] = tensor
     return new_dict
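The reshaped q/k weights above are passed through permute_for_rope, which is not part of this hunk. As a rough sketch of the Llama-style helper these converter scripts typically use (not necessarily the exact code in this file), it regroups each head's rows so the rotary pairs end up in the half-split layout expected by the Hugging Face attention implementation:

import torch


def permute_for_rope(tensor: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Split dim1 into (n_heads, head_dim // 2, 2), swap the pair axis with the half axis,
    # then flatten back: interleaved rotary weights become the "first half / second half"
    # layout used by HF rotary embeddings.
    return tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)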
@@ -169,7 +159,7 @@ def convert_state_dict_sharded(loaded_shards: list[dict], config: MistralConfig)
     return new_dict


-def convert_config(original_config: dict, max_position_embeddings: int):
+def convert_config(original_config: dict, max_position_embeddings: int = 32768):
     key_mapping = {
         "hidden_size": "dim",
         "num_hidden_layers": "n_layers",
@@ -191,9 +181,7 @@ def convert_config(original_config: dict, max_position_embeddings: int):
         "n_kv_heads", new_config_kwargs["num_attention_heads"]
     )
     new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0)
-
-    # This is never provided in `params.json`, we provide it manually
-    new_config_kwargs["max_position_embeddings"] = max_position_embeddings
+    new_config_kwargs["max_position_embeddings"] = original_config.get("max_seq_len", max_position_embeddings)

     # This may sometimes be a string in `params.json`
     if new_config_kwargs["sliding_window"] is not None:
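A small, hypothetical illustration of the new precedence (values invented): max_position_embeddings is now taken from max_seq_len in params.json when present, and only falls back to the function argument (default 32768) otherwise:

# Hypothetical params.json contents; only the keys relevant to this change are shown.
with_max_seq = {"dim": 4096, "max_seq_len": 8192}
without_max_seq = {"dim": 4096}

assert with_max_seq.get("max_seq_len", 32768) == 8192      # value from params.json wins
assert without_max_seq.get("max_seq_len", 32768) == 32768  # falls back to the default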
@@ -230,11 +218,23 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd
     model.save_pretrained(output_dir)


-def convert_and_write_tokenizer(input_dir: str, output_dir: str):
+def convert_and_write_tokenizer(input_dir: str, output_dir: str, tokenizer_template_name: str = ""):
     """Convert the tokenizer and save it."""
-    # May have .v3 or .v7 at the end
-    tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
-    tokenizer = tokenizer_class(os.path.join(input_dir, tokenizer_file))
+    # Tekken format
+    if "tekken.json" in os.listdir(input_dir):
+        tokenizer_file = os.path.join(input_dir, "tekken.json")
+        tokenizer = convert_tekken_tokenizer(tokenizer_file)
+    else:
+        # May have .v3 or .v7 at the end
+        tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
+        tokenizer = LlamaTokenizerFast(os.path.join(input_dir, tokenizer_file))
+
+    # Load a chat template from another model
+    if tokenizer_template_name != "":
+        template_tok = AutoTokenizer.from_pretrained(tokenizer_template_name)
+        tokenizer.chat_template = template_tok.chat_template
+
+    # Finally save it
     tokenizer.save_pretrained(output_dir)
@@ -248,6 +248,12 @@ def main():
         "output_dir",
         help="Location to write HF model and tokenizer",
     )
+    parser.add_argument(
+        "--template_name",
+        type=str,
+        default="",
+        help="Another model name from which to copy the chat template.",
+    )
     parser.add_argument(
         "--max_position_embeddings",
         type=int,
@@ -269,7 +275,7 @@ def main():
     if not args.tokenizer_only:
         convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.modules_are_split)
-    convert_and_write_tokenizer(args.input_dir, args.output_dir)
+    convert_and_write_tokenizer(args.input_dir, args.output_dir, args.template_name)


 if __name__ == "__main__":
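For reference, a hypothetical direct call of the updated tokenizer entry point; the script module name, the directory paths, and the template repo id are placeholders, not values from this PR:

from convert_mistral_weights_to_hf import convert_and_write_tokenizer

convert_and_write_tokenizer(
    "/path/to/raw_mistral_checkpoint",      # holds params.json plus tekken.json or tokenizer.model.v3
    "/path/to/hf_output",
    "mistralai/Mistral-7B-Instruct-v0.3",   # repo whose chat_template gets copied onto the converted tokenizer
)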

@@ -19,8 +19,6 @@ import regex as re
 import torch
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from safetensors.torch import load_file as safe_load_file
-from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
-from tokenizers.models import BPE

 from transformers import (
     LlavaConfig,
@@ -30,7 +28,6 @@ from transformers import (
     PixtralProcessor,
     PixtralVisionConfig,
 )
-from transformers.convert_slow_tokenizer import bytes_to_unicode


 """
@@ -87,76 +84,6 @@ OLD_KEY_TO_NEW_KEY_MAPPING = {
 }


-class MistralConverter:
-    """
-    A general tiktoken converter.
-    """
-
-    def __init__(
-        self,
-        vocab=None,
-        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
-        add_prefix_space=False,
-        additional_special_tokens=None,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args)
-        self.vocab = vocab
-        self.pattern = pattern
-        self.add_prefix_space = add_prefix_space
-        self.additional_special_tokens = additional_special_tokens
-
-    def extract_vocab_merges_from_model(self, vocab: str):
-        bpe_ranks = vocab
-        byte_encoder = bytes_to_unicode()
-
-        def token_bytes_to_string(b):
-            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
-
-        merges = []
-        vocab = {}
-        for idx, (token, rank) in enumerate(bpe_ranks.items()):
-            if token not in self.additional_special_tokens:
-                vocab[token_bytes_to_string(token)] = idx
-                if len(token) == 1:
-                    continue
-                local = []
-                for index in range(1, len(token)):
-                    piece_l, piece_r = token[:index], token[index:]
-                    if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
-                        local.append((piece_l, piece_r, rank))
-                local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
-                merges.extend(local)
-            else:
-                vocab[token] = idx
-        merges = sorted(merges, key=lambda val: val[2], reverse=False)
-        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
-        return vocab, merges
-
-    def tokenizer(self):
-        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
-        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
-        if hasattr(tokenizer.model, "ignore_merges"):
-            tokenizer.model.ignore_merges = True
-        return tokenizer
-
-    def converted(self) -> Tokenizer:
-        tokenizer = self.tokenizer()
-        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
-                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
-            ]
-        )
-        tokenizer.decoder = decoders.ByteLevel()
-        tokenizer.add_special_tokens(self.additional_special_tokens)
-        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
-        return tokenizer
-
-
 def convert_mistral_tokenizer(model_file):
     from transformers import LlamaTokenizer