From 1bdb7bba52d2e030d66d9672841912bb429a67b9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 5 Feb 2025 15:27:08 +0100 Subject: [PATCH] Update modeling_utils.py --- src/transformers/modeling_utils.py | 37 +++++++++++++++--------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 61f5b4d2b..20c2fa129 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -899,7 +899,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name -def get_checkpoint_files( +def _get_checkpoint_files( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], subfolder: str, variant: Optional[str], @@ -916,6 +916,7 @@ def get_checkpoint_files( revision: str, commit_hash: Optional[str], ) -> Tuple[Optional[List[str]], Optional[Dict]]: + """Get the full checkpoint filenames where the weights reside, and optional metadata if the checkpoints are sharded.""" is_sharded = False if pretrained_model_name_or_path is not None and gguf_file is None: @@ -1208,7 +1209,7 @@ def get_checkpoint_files( return checkpoint_files, sharded_metadata -def get_torch_dtype( +def _get_torch_dtype( cls, torch_dtype: Optional[Union[str, torch.dtype, Dict]], checkpoint_files: Optional[List[str]], @@ -1284,14 +1285,14 @@ def get_torch_dtype( return config, torch_dtype, dtype_orig -def get_device_map( +def _get_device_map( model: "PreTrainedModel", device_map: Optional[Union[str, Dict]], max_memory: Optional[Dict], hf_quantizer: Optional[HfQuantizer], torch_dtype: Optional[torch.dtype], keep_in_fp32_modules: Optional[List[str]], -) -> Tuple["PreTrainedModel", Dict]: +) -> Dict: """Compute the final `device_map` to use. Also tie model parameters, and check for any device inconsistencies.""" if isinstance(device_map, str): special_dtypes = {} @@ -1348,10 +1349,10 @@ def get_device_map( # check if we don't have tied param in different devices check_tied_parameters_on_same_device(tied_params, device_map) - return model, device_map + return device_map -def find_missing_and_unexpected_keys( +def _find_missing_and_unexpected_keys( cls, model: "PreTrainedModel", expected_keys: List[str], @@ -1359,7 +1360,7 @@ def find_missing_and_unexpected_keys( remove_prefix_from_model: bool, add_prefix_to_model: bool, hf_quantizer: Optional[HfQuantizer], -) -> Tuple["PreTrainedModel", List[str], List[str]]: +) -> Tuple[List[str], List[str]]: """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys (keys found in the loaded state dict keys, but that are NOT part of the model parameters) """ @@ -1407,10 +1408,10 @@ def find_missing_and_unexpected_keys( if hf_quantizer is not None: missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) - return model, missing_keys, unexpected_keys + return missing_keys, unexpected_keys -def move_missing_keys_back_to_cpu( +def _move_missing_keys_back_to_cpu( model: "PreTrainedModel", missing_keys: List[str], unexpected_keys: List[str], @@ -1457,7 +1458,7 @@ def move_missing_keys_back_to_cpu( return model -def initialize_missing_keys( +def _initialize_missing_keys( model: "PreTrainedModel", loaded_keys: List[str], ignore_mismatched_sizes: bool, @@ -1514,7 +1515,7 @@ def initialize_missing_keys( return model -def find_mismatched_keys( +def _find_mismatched_keys( state_dict: Dict, model_state_dict: Dict, renamed_loaded_keys: List[str], @@ -4312,7 +4313,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub." ) - checkpoint_files, sharded_metadata = get_checkpoint_files( + checkpoint_files, sharded_metadata = _get_checkpoint_files( pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, variant=variant, @@ -4378,7 +4379,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ] # Find the correct dtype based on current state - config, torch_dtype, dtype_orig = get_torch_dtype( + config, torch_dtype, dtype_orig = _get_torch_dtype( cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only ) @@ -4440,7 +4441,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Prepare the full device map if device_map is not None: - model, device_map = get_device_map( + device_map = _get_device_map( model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules ) @@ -4698,7 +4699,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix expected_keys = [".".join([prefix, s]) for s in expected_keys] # Find missing and unexpected keys from the state dict - model, missing_keys, unexpected_keys = find_missing_and_unexpected_keys( + missing_keys, unexpected_keys = _find_missing_and_unexpected_keys( cls, model, expected_keys, @@ -4711,7 +4712,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as # they are not in the loaded state dict) if low_cpu_mem_usage: - model = move_missing_keys_back_to_cpu( + model = _move_missing_keys_back_to_cpu( model, missing_keys, unexpected_keys, dtype, keep_in_fp32_modules, hf_quantizer ) # In this case we also need to move everything back @@ -4721,7 +4722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) # correctly initialize the missing keys - model = initialize_missing_keys( + model = _initialize_missing_keys( model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized ) @@ -4831,7 +4832,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. - mismatched_keys += find_mismatched_keys( + mismatched_keys += _find_mismatched_keys( state_dict, model_state_dict, renamed_loaded_keys,