From 89c740d07a653cebaa8509fa4d192dec9e31c139 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 10 Feb 2025 11:52:55 +0100
Subject: [PATCH] simplify renaming logic

---
 src/transformers/modeling_utils.py | 51 ++++++++++++------------------
 1 file changed, 21 insertions(+), 30 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ab7f85941..1e0ea7ad1 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1439,41 +1439,35 @@ def _find_missing_and_unexpected_keys(
 
 
 def _find_mismatched_keys(
-    model: "PreTrainedModel",
+    model_to_load: "PreTrainedModel",
     state_dict: Dict,
     renamed_loaded_keys: List[str],
-    original_loaded_keys: List[str],
-    loading_base_model_from_task_state_dict: bool,
-    loading_task_model_from_base_state_dict: bool,
     ignore_mismatched_sizes: bool,
+    prefix_to_remove: str,
 ) -> List:
     """Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the
     problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
     mismatched_keys = []
     if ignore_mismatched_sizes:
-        prefix = model.base_model_prefix
-        model_state_dict = model.state_dict()
-        for checkpoint_key, model_key in zip(original_loaded_keys, renamed_loaded_keys):
+        model_state_dict = model_to_load.state_dict()
+        for key in renamed_loaded_keys:
             # If the checkpoint is sharded, we may not have the key here.
-            if checkpoint_key not in state_dict:
+            if key not in state_dict:
                 continue
-            model_key = _adjust_loaded_keys_prefix(
-                model_key, prefix, loading_base_model_from_task_state_dict, loading_task_model_from_base_state_dict
-            )
+            # Remove the prefix if needed
+            adjusted_key = key[len(prefix_to_remove) :]
 
-            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            if adjusted_key in model_state_dict and state_dict[key].shape != model_state_dict[adjusted_key].shape:
                 if (
-                    state_dict[checkpoint_key].shape[-1] == 1
-                    and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
+                    state_dict[key].shape[-1] == 1
+                    and state_dict[key].numel() * 2 == model_state_dict[adjusted_key].numel()
                 ):
                     # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
                     # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
                     pass
                 else:
-                    mismatched_keys.append(
-                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                    )
-                    del state_dict[checkpoint_key]
+                    mismatched_keys.append((key, state_dict[key].shape, model_state_dict[adjusted_key].shape))
+                    del state_dict[key]
 
     return mismatched_keys
 
@@ -4691,7 +4685,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             loaded_state_dict_keys = list(load_state_dict(checkpoint_files[0], weights_only=weights_only).keys())
 
         # Rename the keys (use dummy values in input dict as a small trick)
-        renamed_loaded_keys = list(cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys())
+        renamed_loaded_keys = list(
+            cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
+        )
 
         # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
         prefix = model.base_model_prefix
@@ -4839,9 +4835,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Compute expected model keys
         expected_keys = list(model_to_load.state_dict().keys())
         if hf_quantizer is not None:
-            modified_keys = loaded_state_dict_keys
+            modified_keys = renamed_loaded_keys
             if loading_base_model_from_task_state_dict:
-                modified_keys = [s[len(_prefix) :] for s in loaded_state_dict_keys if s.startswith(_prefix)]
+                modified_keys = [s[len(_prefix) :] for s in modified_keys if s.startswith(_prefix)]
             expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, modified_keys)
 
         error_msgs = []
@@ -4858,20 +4854,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
                 )
 
+            # Modify the keys in-place if needed
+            state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+
             # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
             # matching the weights in the model.
             mismatched_keys += _find_mismatched_keys(
-                model,
-                state_dict,
-                renamed_loaded_keys,
-                loaded_state_dict_keys,
-                loading_base_model_from_task_state_dict,
-                loading_task_model_from_base_state_dict,
-                ignore_mismatched_sizes,
+                model_to_load, state_dict, renamed_loaded_keys, ignore_mismatched_sizes, start_prefix_to_remove
            )
 
             if low_cpu_mem_usage or gguf_file is not None:
-                state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                 # Skip it with fsdp on ranks other than 0
                 if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
                     new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
@@ -4898,7 +4890,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     assign_to_params_buffers = check_support_param_buffer_assignment(
                         model_to_load, state_dict, start_prefix_to_remove
                     )
-                state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                 error_msgs += _load_state_dict_into_model(
                     model_to_load, state_dict, start_prefix_to_remove, assign_to_params_buffers
                 )
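
Below is a minimal, self-contained sketch of the new mismatch check, for readers who want to see the simplified flow in isolation. The function name, toy tensors, and state dicts are illustrative stand-ins rather than the transformers internals; only the prefix stripping and the 4-bit packing heuristic mirror the patched _find_mismatched_keys.

    import torch

    def find_mismatched_keys(model_state_dict, state_dict, prefix_to_remove="", ignore_mismatched_sizes=True):
        """Sketch of the simplified check: one renamed key list, one prefix string."""
        mismatched = []
        if not ignore_mismatched_sizes:
            return mismatched
        for key in list(state_dict.keys()):
            # Strip e.g. "bert." when a base model is loaded from a task checkpoint.
            adjusted_key = key[len(prefix_to_remove):]
            if adjusted_key not in model_state_dict:
                continue
            ckpt, param = state_dict[key], model_state_dict[adjusted_key]
            if ckpt.shape == param.shape:
                continue
            # Two 4-bit values share one 8-bit container, so a valid 4-bit
            # checkpoint tensor holds exactly half the elements of the model's.
            if ckpt.shape[-1] == 1 and ckpt.numel() * 2 == param.numel():
                continue
            mismatched.append((key, ckpt.shape, param.shape))
            del state_dict[key]  # drop the problematic key, as the patch does
        return mismatched

    # Toy usage: one matching key, one genuine shape mismatch.
    model_sd = {"linear.weight": torch.zeros(4, 4), "head.weight": torch.zeros(2, 4)}
    ckpt_sd = {"bert.linear.weight": torch.zeros(4, 4), "bert.head.weight": torch.zeros(3, 4)}
    print(find_mismatched_keys(model_sd, ckpt_sd, prefix_to_remove="bert."))
    # -> [('bert.head.weight', torch.Size([3, 4]), torch.Size([2, 4]))]

The collapse from six parameters to four works because each shard's keys are now renamed once via cls._fix_state_dict_keys_on_load before the check runs, so a single prefix string is all that remains to reconcile checkpoint keys with model keys.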