mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Update Mistral converter (#35967)
* Update convert_mistral_weights_to_hf.py
* Update convert_mistral_weights_to_hf.py
* update
* style
* move it to integrations
* style
* trigger CIs
* trigger CIs
This commit is contained in:
parent b1954fd64a
commit ad30598923
3 changed files with 142 additions and 104 deletions
src/transformers/integrations/mistral.py (new file, +105)

@@ -0,0 +1,105 @@
+from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
+from tokenizers.models import BPE
+
+from transformers import LlamaTokenizerFast
+from transformers.convert_slow_tokenizer import bytes_to_unicode
+
+
+class MistralConverter:
+    """
+    A general tiktoken converter.
+    """
+
+    def __init__(
+        self,
+        vocab=None,
+        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
+        add_prefix_space=False,
+        additional_special_tokens=None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args)
+        self.vocab = vocab
+        self.pattern = pattern
+        self.add_prefix_space = add_prefix_space
+        self.additional_special_tokens = additional_special_tokens
+
+    def extract_vocab_merges_from_model(self, vocab: str):
+        bpe_ranks = vocab
+        byte_encoder = bytes_to_unicode()
+
+        def token_bytes_to_string(b):
+            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
+
+        merges = []
+        vocab = {}
+        for idx, (token, rank) in enumerate(bpe_ranks.items()):
+            if token not in self.additional_special_tokens:
+                vocab[token_bytes_to_string(token)] = idx
+                if len(token) == 1:
+                    continue
+                local = []
+                for index in range(1, len(token)):
+                    piece_l, piece_r = token[:index], token[index:]
+                    if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
+                        local.append((piece_l, piece_r, rank))
+                local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
+                merges.extend(local)
+            else:
+                vocab[token] = idx
+        merges = sorted(merges, key=lambda val: val[2], reverse=False)
+        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
+        return vocab, merges
+
+    def tokenizer(self):
+        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
+        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
+        if hasattr(tokenizer.model, "ignore_merges"):
+            tokenizer.model.ignore_merges = True
+        return tokenizer
+
+    def converted(self) -> Tokenizer:
+        tokenizer = self.tokenizer()
+        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
+                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
+            ]
+        )
+        tokenizer.decoder = decoders.ByteLevel()
+        tokenizer.add_special_tokens(self.additional_special_tokens)
+
+        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+        return tokenizer
+
+
+def convert_tekken_tokenizer(tokenizer_file: str):
+    """Convert a "tekken" tokenizer to a fast Tokenizer."""
+    # Tekken format -- need to use the Converter
+
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+    # Load directly using their lib
+    mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file)
+
+    # Extract vocab and special tokens
+    vocab = mistral_tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
+    all_special = [
+        token.value if hasattr(token, "value") else token
+        for token in mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
+    ]
+    specials_tokens = {token: all_special.index(token) for token in all_special}
+    specials_tokens.update(vocab)
+    vocab = specials_tokens
+
+    # Convert
+    tokenizer = LlamaTokenizerFast(
+        tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
+    )
+
+    # Post-process
+    tokenizer.add_special_tokens({"additional_special_tokens": all_special})
+
+    return tokenizer
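For reference, a minimal usage sketch of the new helper (not part of the diff): it assumes `mistral_common` is installed and that a `tekken.json` file from a Mistral release is available locally; the paths below are placeholders.

# Sketch only: convert a local "tekken" tokenizer file and round-trip a string.
from transformers.integrations.mistral import convert_tekken_tokenizer

tokenizer = convert_tekken_tokenizer("path/to/tekken.json")  # returns a LlamaTokenizerFast

ids = tokenizer.encode("Hello Mistral!")
print(ids, tokenizer.decode(ids))

tokenizer.save_pretrained("path/to/converted_tokenizer")  # save like any other HF tokenizer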
convert_mistral_weights_to_hf.py

@@ -15,25 +15,14 @@ import argparse
 import json
 import os
 import re
-import warnings
 
 import torch
 from safetensors.torch import load_file
 
-from transformers import LlamaTokenizer, MistralConfig, MistralForCausalLM
+from transformers import AutoTokenizer, LlamaTokenizerFast, MistralConfig, MistralForCausalLM
+from transformers.integrations.mistral import convert_tekken_tokenizer
 
 
-try:
-    from transformers import LlamaTokenizerFast
-
-    tokenizer_class = LlamaTokenizerFast
-except ImportError as e:
-    warnings.warn(e)
-    warnings.warn(
-        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
-    )
-    tokenizer_class = LlamaTokenizer
-
 # fmt: off
 STATE_DICT_MAPPING = {
     # CausalLM keys
@@ -87,23 +76,24 @@ def convert_state_dict(original_state_dict: dict, config: MistralConfig):
     """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
     new_dict = {}
 
-    n_heads = config.num_attention_heads
-    dim = config.hidden_size
-    dims_per_head = dim // n_heads
+    num_attention_heads = config.num_attention_heads
+    hidden_size = config.hidden_size
+    head_dim = config.head_dim
     num_key_value_heads = config.num_key_value_heads
-    key_value_dim = dims_per_head * num_key_value_heads
+    key_value_dim = head_dim * num_key_value_heads
+    query_dim = head_dim * num_attention_heads
 
     for old_key, tensor in original_state_dict.items():
         new_key = map_old_key_to_new(old_key)
 
         if "q_proj" in new_key:
-            tensor = tensor.view(n_heads, dims_per_head, dim).reshape(dim, dim)
-            tensor = permute_for_rope(tensor, n_heads, dim, dim)
+            tensor = tensor.view(num_attention_heads, head_dim, hidden_size).reshape(query_dim, hidden_size)
+            tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size)
         elif "k_proj" in new_key:
-            tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
-            tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim)
+            tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)
+            tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
         elif "v_proj" in new_key:
-            tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
+            tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)
 
         new_dict[new_key] = tensor
     return new_dict
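Note on the hunk above: reading `config.head_dim` instead of deriving `hidden_size // num_attention_heads` matters for checkpoints whose head dimension is decoupled from the hidden size, so the q/k projections are no longer square. A small shape sketch with made-up dimensions; `permute_for_rope` below is a stand-in using the usual GPT-NeoX-style reordering, not necessarily the exact helper defined elsewhere in this script.

import torch

# Made-up dimensions where head_dim != hidden_size // num_attention_heads (sketch only)
num_attention_heads, num_key_value_heads = 4, 2
hidden_size, head_dim = 32, 16
query_dim = head_dim * num_attention_heads      # 64
key_value_dim = head_dim * num_key_value_heads  # 32

def permute_for_rope(tensor, n_heads, dim1, dim2):
    # Assumed implementation: reorder rows so the two rotary halves are contiguous per head.
    return tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

q = torch.randn(query_dim, hidden_size)  # q_proj weight as stored in the original checkpoint
q = q.view(num_attention_heads, head_dim, hidden_size).reshape(query_dim, hidden_size)
q = permute_for_rope(q, num_attention_heads, query_dim, hidden_size)
print(q.shape)  # torch.Size([64, 32]) -- rectangular, which the old `dim x dim` reshape could not handle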
@@ -169,7 +159,7 @@ def convert_state_dict_sharded(loaded_shards: list[dict], config: MistralConfig)
     return new_dict
 
 
-def convert_config(original_config: dict, max_position_embeddings: int):
+def convert_config(original_config: dict, max_position_embeddings: int = 32768):
     key_mapping = {
         "hidden_size": "dim",
         "num_hidden_layers": "n_layers",
@@ -191,9 +181,7 @@ def convert_config(original_config: dict, max_position_embeddings: int):
         "n_kv_heads", new_config_kwargs["num_attention_heads"]
     )
     new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0)
-
-    # This is never provided in `params.json`, we provide it manually
-    new_config_kwargs["max_position_embeddings"] = max_position_embeddings
+    new_config_kwargs["max_position_embeddings"] = original_config.get("max_seq_len", max_position_embeddings)
 
     # This may sometimes be a string in `params.json`
     if new_config_kwargs["sliding_window"] is not None:
@@ -230,11 +218,23 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd
     model.save_pretrained(output_dir)
 
 
-def convert_and_write_tokenizer(input_dir: str, output_dir: str):
+def convert_and_write_tokenizer(input_dir: str, output_dir: str, tokenizer_template_name: str = ""):
     """Convert the tokenizer and save it."""
-    # May have .v3 or .v7 at the end
-    tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
-    tokenizer = tokenizer_class(os.path.join(input_dir, tokenizer_file))
+    # Tekken format
+    if "tekken.json" in os.listdir(input_dir):
+        tokenizer_file = os.path.join(input_dir, "tekken.json")
+        tokenizer = convert_tekken_tokenizer(tokenizer_file)
+    else:
+        # May have .v3 or .v7 at the end
+        tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
+        tokenizer = LlamaTokenizerFast(os.path.join(input_dir, tokenizer_file))
+
+    # Load a chat template from another model
+    if tokenizer_template_name != "":
+        template_tok = AutoTokenizer.from_pretrained(tokenizer_template_name)
+        tokenizer.chat_template = template_tok.chat_template
+
+    # Finally save it
     tokenizer.save_pretrained(output_dir)
 
 
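The new `tokenizer_template_name` path simply copies a chat template from an existing tokenizer onto the converted one. The same idea in isolation (sketch; the hub id and paths are examples only and the download requires network access):

from transformers import AutoTokenizer

# Example ids/paths only
template_tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
converted_tok = AutoTokenizer.from_pretrained("path/to/converted_model")

converted_tok.chat_template = template_tok.chat_template
converted_tok.save_pretrained("path/to/converted_model")

# The converted tokenizer can now render conversations
print(converted_tok.apply_chat_template(
    [{"role": "user", "content": "Hello!"}], tokenize=False, add_generation_prompt=True
))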
@@ -248,6 +248,12 @@ def main():
         "output_dir",
         help="Location to write HF model and tokenizer",
     )
+    parser.add_argument(
+        "--template_name",
+        type=str,
+        default="",
+        help="Another model name from which to copy the chat template.",
+    )
     parser.add_argument(
         "--max_position_embeddings",
         type=int,
@@ -269,7 +275,7 @@ def main():
 
     if not args.tokenizer_only:
         convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.modules_are_split)
-    convert_and_write_tokenizer(args.input_dir, args.output_dir)
+    convert_and_write_tokenizer(args.input_dir, args.output_dir, args.template_name)
 
 
 if __name__ == "__main__":
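End to end, a conversion run with this script boils down to the two calls below. A sketch of driving them from Python rather than the CLI; the import path is an assumption about the script's usual location in the repo, and the directories and template id are placeholders.

# Sketch: programmatic equivalent of `python convert_mistral_weights_to_hf.py <input_dir> <output_dir> --template_name ...`
# The module path below is an assumption, not part of the script's public API.
from transformers.models.mistral.convert_mistral_weights_to_hf import (
    convert_and_write_model,
    convert_and_write_tokenizer,
)

input_dir = "path/to/raw_mistral_checkpoint"   # params.json, weights, and tekken.json or tokenizer.model.v3
output_dir = "path/to/hf_output"

convert_and_write_model(input_dir, output_dir, 32768, False)  # max_position_embeddings, modules_are_split
convert_and_write_tokenizer(input_dir, output_dir, "mistralai/Mistral-7B-Instruct-v0.3")  # optional template source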
src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
@@ -19,8 +19,6 @@ import regex as re
 import torch
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from safetensors.torch import load_file as safe_load_file
-from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
-from tokenizers.models import BPE
 
 from transformers import (
     LlavaConfig,
@@ -30,7 +28,6 @@ from transformers import (
     PixtralProcessor,
     PixtralVisionConfig,
 )
-from transformers.convert_slow_tokenizer import bytes_to_unicode
 
 
 """
@@ -87,76 +84,6 @@ OLD_KEY_TO_NEW_KEY_MAPPING = {
 }
 
 
-class MistralConverter:
-    """
-    A general tiktoken converter.
-    """
-
-    def __init__(
-        self,
-        vocab=None,
-        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
-        add_prefix_space=False,
-        additional_special_tokens=None,
-        *args,
-        **kwargs,
-    ):
-        super().__init__(*args)
-        self.vocab = vocab
-        self.pattern = pattern
-        self.add_prefix_space = add_prefix_space
-        self.additional_special_tokens = additional_special_tokens
-
-    def extract_vocab_merges_from_model(self, vocab: str):
-        bpe_ranks = vocab
-        byte_encoder = bytes_to_unicode()
-
-        def token_bytes_to_string(b):
-            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
-
-        merges = []
-        vocab = {}
-        for idx, (token, rank) in enumerate(bpe_ranks.items()):
-            if token not in self.additional_special_tokens:
-                vocab[token_bytes_to_string(token)] = idx
-                if len(token) == 1:
-                    continue
-                local = []
-                for index in range(1, len(token)):
-                    piece_l, piece_r = token[:index], token[index:]
-                    if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
-                        local.append((piece_l, piece_r, rank))
-                local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
-                merges.extend(local)
-            else:
-                vocab[token] = idx
-        merges = sorted(merges, key=lambda val: val[2], reverse=False)
-        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
-        return vocab, merges
-
-    def tokenizer(self):
-        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
-        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
-        if hasattr(tokenizer.model, "ignore_merges"):
-            tokenizer.model.ignore_merges = True
-        return tokenizer
-
-    def converted(self) -> Tokenizer:
-        tokenizer = self.tokenizer()
-        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
-                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
-            ]
-        )
-        tokenizer.decoder = decoders.ByteLevel()
-        tokenizer.add_special_tokens(self.additional_special_tokens)
-
-        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
-
-        return tokenizer
-
-
 def convert_mistral_tokenizer(model_file):
     from transformers import LlamaTokenizer
 