From d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Tue, 23 Jul 2024 16:58:17 +0200
Subject: [PATCH] Llama 3.1 conversion

Co-authored-by: Arthur Zucker
---
 src/transformers/modeling_rope_utils.py      | 103 +++++++++++-
 .../models/llama/configuration_llama.py      |  31 ++--
 .../llama/convert_llama_weights_to_hf.py     | 151 +++++++++++++-----
 3 files changed, 227 insertions(+), 58 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index 33055d2bf..14a12b939 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -129,6 +129,7 @@ def _compute_dynamic_ntk_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     if config is not None and len(rope_kwargs) > 0:
         raise ValueError(
             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
@@ -249,6 +250,7 @@ def _compute_longrope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     # No need to keep BC with longrope, unreleased when this new pattern was created.
     if len(rope_kwargs) > 0:
         raise ValueError(
@@ -293,6 +295,50 @@ def _compute_longrope_parameters(
     return inv_freq, attention_factor
 
 
+def _compute_llama3_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in inv_freq:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
+    return inv_freq, attention_factor
+
+
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
 # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
 # parameterizations, as long as the callable has the same signature.
@@ -302,6 +348,7 @@ ROPE_INIT_FUNCTIONS = {
     "dynamic": _compute_dynamic_ntk_parameters,
     "yarn": _compute_yarn_parameters,
     "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
 }
 
 
@@ -339,6 +386,20 @@ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
         raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
 
 
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"original_max_position_embeddings"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
 def _validate_yarn_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
@@ -374,7 +435,8 @@ def _validate_longrope_parameters(config: PretrainedConfig):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling["rope_type"]
     required_keys = {"rope_type", "short_factor", "long_factor"}
-    optional_keys = {"attention_factor", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
 
@@ -417,13 +479,50 @@ def _validate_longrope_parameters(config: PretrainedConfig):
         )
 
 
+def _validate_llama3_parameters(config: PretrainedConfig):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling["rope_type"]
+    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    if low_freq_factor is None or not isinstance(low_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+    if high_freq_factor is None or not isinstance(high_freq_factor, float):
+        raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+    if high_freq_factor < low_freq_factor:
+        raise ValueError(
+            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+        )
+
+    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+            f"{original_max_position_embeddings}"
+        )
+    if original_max_position_embeddings >= config.max_position_embeddings:
+        raise ValueError(
+            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+        )
+
+
 # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
 ROPE_VALIDATION_FUNCTIONS = {
     "default": _validate_default_rope_parameters,
     "linear": _validate_linear_scaling_rope_parameters,
-    "dynamic": _validate_linear_scaling_rope_parameters,  # `dynamic` has the same validation pattern as `linear`
+    "dynamic": _validate_dynamic_scaling_rope_parameters,
     "yarn": _validate_yarn_parameters,
     "longrope": _validate_longrope_parameters,
+    "llama3": _validate_llama3_parameters,
 }
 
 
diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py
index 7c987ec85..c632a870b 100644
--- a/src/transformers/models/llama/configuration_llama.py
+++ b/src/transformers/models/llama/configuration_llama.py
@@ -73,25 +73,28 @@ class LlamaConfig(PretrainedConfig):
             End of stream token id.
         pretraining_tp (`int`, *optional*, defaults to 1):
             Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
-            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is
-            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
-            issue](https://github.com/pytorch/pytorch/issues/76232).
+            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
-            `max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
-            to determine which scaling to apply.
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
+            type and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
             Expected contents:
                 `rope_type` (`str`):
-                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
-                    with 'default' being the original RoPE implementation.
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
                 `factor` (`float`, *optional*):
                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-                    `max_position_embeddings`.
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
                 `attention_factor` (`float`, *optional*):
                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                     computation. If unspecified, it defaults to value recommended by the implementation, using the
@@ -104,12 +107,16 @@ class LlamaConfig(PretrainedConfig):
                     ramp function. If unspecified, it defaults to 1.
                 `short_factor` (`List[float]`, *optional*):
                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
                 `long_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                     size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py
index fd6ab4f2e..384daab6b 100644
--- a/src/transformers/models/llama/convert_llama_weights_to_hf.py
+++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -17,10 +17,11 @@
 import json
 import os
 import shutil
 import warnings
+from typing import List
 
 import torch
 
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
+from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
 from transformers.convert_slow_tokenizer import TikTokenConverter
 
@@ -85,8 +86,12 @@ NUM_SHARDS = {
     "65B": 8,
     "70B": 8,
     "70Bf": 8,
+    "405B": 8,
+    "405B-MP16": 16,
 }
 
+CONTEXT_LENGTH_FOR_VERSION = {"3.1": 131072, "3": 8192, "2": 4096, "1": 2048}
+
 
 def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
     return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
@@ -107,9 +112,10 @@ def write_model(
     input_base_path,
     model_size=None,
     safe_serialization=True,
-    llama_version=1,
+    llama_version="1",
     vocab_size=None,
     num_shards=None,
+    instruct=False,
 ):
     os.makedirs(model_path, exist_ok=True)
     tmp_model_path = os.path.join(model_path, "tmp")
@@ -125,18 +131,11 @@ def write_model(
     dims_per_head = dim // n_heads
     base = params.get("rope_theta", 10000.0)
     inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    if base > 10000.0 and llama_version != 3:
+    if base > 10000.0 and float(llama_version) < 3:
         max_position_embeddings = 16384
     else:
-        # Depending on the Llama version, the default max_position_embeddings has different values.
-        if llama_version == 1:
-            max_position_embeddings = 2048
-        elif llama_version == 2:
-            max_position_embeddings = 4096
-        elif llama_version == 3:
-            max_position_embeddings = 8192
+        max_position_embeddings = CONTEXT_LENGTH_FOR_VERSION[llama_version]
 
-    vocab_size = vocab_size if vocab_size is not None else 32000
     if params.get("n_kv_heads", None) is not None:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
         num_key_value_heads_per_shard = num_key_value_heads // num_shards
@@ -144,8 +143,7 @@ def write_model(
     else:  # compatibility with other checkpoints
         num_key_value_heads = n_heads
         num_key_value_heads_per_shard = n_heads_per_shard
-        key_value_dim = dims_per_head * num_key_value_heads
-        print(num_shards, num_key_value_heads, num_key_value_heads_per_shard, key_value_dim)
+        key_value_dim = dim
 
     # permute for sliced rotary
     def permute(w, n_heads, dim1=dim, dim2=dim):
@@ -159,11 +157,9 @@ def write_model(
         loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
     else:
         # Sharded
-        loaded = [
-            torch.load(os.path.join(input_base_path, file), map_location="cpu")
-            for file in sorted(os.listdir(input_base_path))
-            if file.endswith(".pth")
-        ]
+        checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
+        print("Loading in order:", checkpoint_list)
+        loaded = [torch.load(os.path.join(input_base_path, file), map_location="cpu") for file in checkpoint_list]
     param_count = 0
     index_dict = {"weight_map": {}}
     for layer_i in range(n_layers):
@@ -263,7 +259,7 @@ def write_model(
             "lm_head.weight": loaded["output.weight"],
         }
     else:
-        concat_dim = 0 if llama_version == 3 else 1
+        concat_dim = 0 if llama_version in ["3", "3.1"] else 1
         state_dict = {
             "model.norm.weight": loaded[0]["norm.weight"],
             "model.embed_tokens.weight": torch.cat(
@@ -282,6 +278,18 @@ def write_model(
     write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
     ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
     multiple_of = params["multiple_of"] if "multiple_of" in params else 256
+
+    if llama_version in ["3", "3.1"]:
+        bos_token_id = 128000
+
+        if instruct:
+            eos_token_id = [128001, 128008, 128009]
+        else:
+            eos_token_id = 128001
+    else:
+        bos_token_id = 1
+        eos_token_id = 2
+
     config = LlamaConfig(
         hidden_size=dim,
         intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
@@ -292,11 +300,21 @@ def write_model(
         vocab_size=vocab_size,
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        bos_token_id=128000 if llama_version == 3 else 1,
-        eos_token_id=128001 if llama_version == 3 else 2,
+        bos_token_id=bos_token_id,
+        eos_token_id=eos_token_id,
     )
     config.save_pretrained(tmp_model_path)
 
+    if instruct:
+        generation_config = GenerationConfig(
+            do_sample=True,
+            temperature=0.6,
+            top_p=0.9,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+        )
+        generation_config.save_pretrained(tmp_model_path)
+
     # Make space so we can load the model properly now.
     del state_dict
     del loaded
@@ -313,7 +331,7 @@ def write_model(
 
 
 class Llama3Converter(TikTokenConverter):
-    def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs):
+    def __init__(self, vocab_file, special_tokens=None, instruct=False, model_max_length=None, **kwargs):
         super().__init__(vocab_file, **kwargs)
         tokenizer = self.converted()
         chat_template = (
@@ -327,34 +345,24 @@ class Llama3Converter(TikTokenConverter):
             "{% endfor %}"
             "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
         )
-        num_reserved_special_tokens = 256
-        special_tokens = [
-            "<|begin_of_text|>",
-            "<|end_of_text|>",
-            "<|reserved_special_token_0|>",
-            "<|reserved_special_token_1|>",
-            "<|reserved_special_token_2|>",
-            "<|reserved_special_token_3|>",
-            "<|start_header_id|>",
-            "<|end_header_id|>",
-            "<|reserved_special_token_4|>",
-            "<|eot_id|>",  # end of turn
-        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, num_reserved_special_tokens - 5)]
         tokenizer.add_special_tokens(special_tokens)
         self.tokenizer = PreTrainedTokenizerFast(
             tokenizer_object=tokenizer,
             bos_token="<|begin_of_text|>",
-            eos_token="<|end_of_text|>",
-            chat_template=chat_template,
+            eos_token="<|end_of_text|>" if not instruct else "<|eot_id|>",
+            chat_template=chat_template if instruct else None,
             model_input_names=["input_ids", "attention_mask"],
+            model_max_length=model_max_length,
         )
 
 
-def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
+def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version="2", special_tokens=None, instruct=False):
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
-    if llama_version == 3:
-        tokenizer = Llama3Converter(input_tokenizer_path).tokenizer
+    if llama_version in ["3", "3.1"]:
+        tokenizer = Llama3Converter(
+            input_tokenizer_path, special_tokens, instruct, model_max_length=CONTEXT_LENGTH_FOR_VERSION[llama_version]
+        ).tokenizer
     else:
         tokenizer = tokenizer_class(input_tokenizer_path)
     print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
@@ -362,6 +370,37 @@ def write_tokenizer(tokenizer_path, input_tokenizer_path, llama_version=2):
     return tokenizer
 
 
+DEFAULT_LLAMA_SPECIAL_TOKENS = {
+    "3": [
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|reserved_special_token_2|>",
+        "<|reserved_special_token_3|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|reserved_special_token_4|>",
+        "<|eot_id|>",  # end of turn
+    ]
+    + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)],
+    "3.1": [
+        "<|begin_of_text|>",
+        "<|end_of_text|>",
+        "<|reserved_special_token_0|>",
+        "<|reserved_special_token_1|>",
+        "<|finetune_right_pad_id|>",
+        "<|reserved_special_token_2|>",
+        "<|start_header_id|>",
+        "<|end_header_id|>",
+        "<|eom_id|>",  # end of message
+        "<|eot_id|>",  # end of turn
+        "<|python_tag|>",
+    ]
+    + [f"<|reserved_special_token_{i}|>" for i in range(3, 256 - 8)],
+}
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -383,9 +422,9 @@ def main():
     # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
     parser.add_argument(
         "--llama_version",
-        choices=[1, 2, 3],
-        default=1,
-        type=int,
+        choices=["1", "2", "3", "3.1"],
+        default="1",
+        type=str,
         help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
     )
     parser.add_argument(
@@ -394,11 +433,34 @@ def main():
         "--num_shards",
         default=None,
         type=int,
         help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
     )
+    parser.add_argument(
+        "--special_tokens",
+        default=None,
+        type=List[str],
+        help="The list of special tokens that should be added to the model.",
+    )
+    parser.add_argument(
+        "--instruct",
+        default=False,
+        type=bool,
+        help="Whether the model is an instruct model or not. Will affect special tokens for llama 3.1.",
+    )
     args = parser.parse_args()
     if args.model_size is None and args.num_shards is None:
         raise ValueError("You have to set at least `num_shards` if you are not giving the `model_size`")
+    if args.special_tokens is None:
+        args.special_tokens = DEFAULT_LLAMA_SPECIAL_TOKENS[str(args.llama_version)]
+
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
-    vocab_size = len(write_tokenizer(args.output_dir, spm_path, llama_version=args.llama_version))
+    vocab_size = len(
+        write_tokenizer(
+            args.output_dir,
+            spm_path,
+            llama_version=args.llama_version,
+            special_tokens=args.special_tokens,
+            instruct=args.instruct,
+        )
+    )
     if args.model_size != "tokenizer_only":
         write_model(
             model_path=args.output_dir,
@@ -408,6 +470,7 @@ def main():
             llama_version=args.llama_version,
             vocab_size=vocab_size,
             num_shards=args.num_shards,
+            instruct=args.instruct,
         )
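
Reviewer note (not part of the patch): the new "llama3" rope type boils down to a per-frequency rescaling of the default RoPE inverse frequencies. The standalone sketch below mirrors the rule added in `_compute_llama3_parameters` so it can be sanity-checked outside the library. The helper name `rescale_llama3_inv_freq` and the concrete numbers (`head_dim=128`, `rope_theta=500000.0`, and the `factor`/`low_freq_factor`/`high_freq_factor`/`original_max_position_embeddings` defaults quoted in the code comments above) are illustrative assumptions, not values read from any checkpoint.

    # Standalone sketch of the "llama3" RoPE rescaling rule (mirrors _compute_llama3_parameters).
    import math

    import torch


    def rescale_llama3_inv_freq(
        inv_freq: torch.Tensor,
        factor: float = 8.0,
        low_freq_factor: float = 1.0,
        high_freq_factor: float = 4.0,
        original_max_position_embeddings: int = 8192,
    ) -> torch.Tensor:
        low_freq_wavelen = original_max_position_embeddings / low_freq_factor
        high_freq_wavelen = original_max_position_embeddings / high_freq_factor
        new_freqs = []
        for freq in inv_freq:
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                # High-frequency components (short wavelengths) are left untouched.
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                # Low-frequency components are scaled down by `factor`.
                new_freqs.append(freq / factor)
            else:
                # Components in between are linearly interpolated between the two regimes.
                smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                    high_freq_factor - low_freq_factor
                )
                new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
        return torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)


    if __name__ == "__main__":
        # Default RoPE inverse frequencies for an assumed head_dim of 128 and rope_theta of 500000.0.
        head_dim, rope_theta = 128, 500000.0
        inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        print(rescale_llama3_inv_freq(inv_freq)[:4])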
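
Reviewer note (not part of the patch): the sketch below shows the shape of a `rope_scaling` dictionary that passes the new `_validate_llama3_parameters` checks (all five required keys present, `factor` >= 1.0, `high_freq_factor` larger than `low_freq_factor`, and `original_max_position_embeddings` strictly smaller than `max_position_embeddings`). The concrete sizes are illustrative defaults, not taken from a released checkpoint.

    # Illustrative config carrying the keys required by the new "llama3" rope validation.
    from transformers import LlamaConfig

    config = LlamaConfig(
        max_position_embeddings=131072,  # must be larger than original_max_position_embeddings
        rope_theta=500000.0,  # illustrative base period
        rope_scaling={
            "rope_type": "llama3",
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
    )
    print(config.rope_scaling)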