From 11e378024d3f16b16d897b25179bb76020875cc0 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Wed, 5 Feb 2025 13:52:57 +0100
Subject: [PATCH] add type hints/docstring

---
 src/transformers/modeling_utils.py | 248 +++++++++++++++++------
 1 file changed, 144 insertions(+), 104 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index cc5052790..446f8847f 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -752,22 +752,21 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
 
 
 def _load_state_dict_into_meta_model(
-    model,
-    state_dict,
-    start_prefix,
-    expected_keys,
-    device_map=None,
-    offload_folder=None,
-    offload_index=None,
-    state_dict_folder=None,
-    state_dict_index=None,
-    dtype=None,
-    hf_quantizer=None,
-    is_safetensors=False,
-    keep_in_fp32_modules=None,
-    unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
-    pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
-):
+    model: "PreTrainedModel",
+    state_dict: Dict,
+    start_prefix: str,
+    expected_keys: Dict,
+    device_map: Optional[Dict] = None,
+    disk_offload_folder: Optional[str] = None,
+    disk_offload_index: Optional[Dict] = None,
+    cpu_offload_folder: Optional[str] = None,
+    cpu_offload_index: Optional[Dict] = None,
+    dtype: Optional[torch.dtype] = None,
+    hf_quantizer: Optional[HfQuantizer] = None,
+    is_safetensors: bool = False,
+    keep_in_fp32_modules: Optional[List[str]] = None,
+    unexpected_keys: Optional[Dict] = None,  # passing `unexpected` for cleanup from quantization items
+) -> Tuple[List[str], Optional[Dict], Optional[Dict]]:
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
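Aside: the `meta`-device mechanism this helper builds on is easy to demonstrate in isolation. The sketch below is plain PyTorch, not the transformers implementation (which additionally handles device maps, offloading and quantization): parameters are first materialized on the `meta` device so they consume no memory, then swapped for real tensors from the state dict one at a time.

    import torch
    from torch import nn

    # Build the model skeleton on the meta device: tensors carry shape/dtype but no storage.
    with torch.device("meta"):
        model = nn.Linear(4, 4)

    state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

    # Swap each meta parameter for real data, one tensor at a time.
    for name, tensor in state_dict.items():
        module_path, _, param_name = name.rpartition(".")
        module = model.get_submodule(module_path) if module_path else model
        setattr(module, param_name, nn.Parameter(tensor))

    print(model.weight.device)  # cpu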
@@ -857,9 +856,9 @@ def _load_state_dict_into_meta_model(
 
         if param_device == "disk":
             if not is_safetensors:
-                offload_index = offload_weight(param, param_name, offload_folder, offload_index)
-        elif param_device == "cpu" and state_dict_index is not None:
-            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+                disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
+        elif param_device == "cpu" and cpu_offload_index is not None:
+            cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index)
         elif (
             not is_quantized
             or (not hf_quantizer.requires_parameters_quantization)
@@ -892,7 +891,7 @@ def _load_state_dict_into_meta_model(
                     setattr(module, tensor_name, value)
 
     # TODO: consider removing used param_parts from state_dict before return
-    return error_msgs, offload_index, state_dict_index
+    return error_msgs, disk_offload_index, cpu_offload_index
 
 
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
@@ -905,7 +904,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
 
 
 def get_checkpoint_files(
-    pretrained_model_name_or_path: Union[str, os.PathLike],
+    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
     subfolder: str,
     variant: Optional[str],
     gguf_file: Optional[str],
@@ -920,7 +919,7 @@ def get_checkpoint_files(
     user_agent: dict,
     revision: str,
     commit_hash: Optional[str],
-) -> Tuple[List[str], Dict]:
+) -> Tuple[Optional[List[str]], Optional[Dict]]:
     is_sharded = False
 
     if pretrained_model_name_or_path is not None and gguf_file is None:
@@ -1216,12 +1215,15 @@ def get_checkpoint_files(
 def get_torch_dtype(
     cls,
     torch_dtype: Optional[Union[str, torch.dtype, Dict]],
-    checkpoint_files: List[str],
+    checkpoint_files: Optional[List[str]],
     config: PretrainedConfig,
     sharded_metadata: Optional[Dict],
     state_dict: Optional[Dict],
     weights_only: bool,
-):
+) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
+    """Find the correct `torch_dtype` to use based on the provided arguments. Also update the `config` based on the
+    inferred dtype.
+    """
     dtype_orig = None
     is_sharded = sharded_metadata is not None
 
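The `Optional[Union[str, torch.dtype, Dict]]` hint mirrors what the public `from_pretrained` accepts for `torch_dtype`. The two common forms are shown below (checkpoint name is illustrative); the `Dict` form assigns per-component dtypes for composite models.

    import torch
    from transformers import AutoModelForCausalLM

    # Explicit dtype: floating-point weights are loaded in bfloat16.
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)

    # "auto": derive the dtype from the checkpoint/config instead of defaulting to float32.
    model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype="auto")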
@@ -1282,7 +1284,15 @@ def get_torch_dtype(
     return config, torch_dtype, dtype_orig
 
 
-def get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules):
+def get_device_map(
+    model: "PreTrainedModel",
+    device_map: Optional[Union[str, Dict]],
+    max_memory: Optional[Dict],
+    hf_quantizer: Optional[HfQuantizer],
+    torch_dtype: Optional[torch.dtype],
+    keep_in_fp32_modules: Optional[List[str]],
+) -> Tuple["PreTrainedModel", Dict]:
+    """Compute the final `device_map` to use. Also tie the model parameters, and check for any device inconsistencies."""
     if isinstance(device_map, str):
         special_dtypes = {}
         if hf_quantizer is not None:
@@ -1342,8 +1352,18 @@ def get_device_map(
 
 
 def find_missing_and_unexpected_keys(
-    cls, model, expected_keys, loaded_keys, remove_prefix_from_model, add_prefix_to_model, hf_quantizer, device_map
-):
+    cls,
+    model: "PreTrainedModel",
+    expected_keys: List[str],
+    loaded_keys: List[str],
+    remove_prefix_from_model: bool,
+    add_prefix_to_model: bool,
+    hf_quantizer: Optional[HfQuantizer],
+    device_map: Optional[Dict],
+) -> Tuple["PreTrainedModel", List[str], List[str]]:
+    """Find the missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict
+    keys) and the unexpected keys (keys found in the loaded state dict keys, but that are NOT part of the model parameters).
+    """
     prefix = model.base_model_prefix
     _prefix = f"{prefix}."
 
@@ -1362,7 +1382,7 @@ def find_missing_and_unexpected_keys(
     # Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858)
     has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
     if has_inv_freq_buffers:
-        unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k}
+        unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
 
     model.tie_weights()
     if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
@@ -1402,14 +1422,17 @@ def find_missing_and_unexpected_keys(
 
 
 def move_missing_keys_back_to_cpu(
-    model,
-    missing_keys,
-    unexpected_keys,
-    dtype,
-    keep_in_fp32_modules,
-    is_quantized,
-    hf_quantizer,
-):
+    model: "PreTrainedModel",
+    missing_keys: List[str],
+    unexpected_keys: List[str],
+    dtype: Optional[torch.dtype],
+    keep_in_fp32_modules: Optional[List[str]],
+    hf_quantizer: Optional[HfQuantizer],
+) -> "PreTrainedModel":
+    """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state
+    dicts) back from the meta device to cpu.
+    """
+    is_quantized = hf_quantizer is not None
     prefix = model.base_model_prefix
     model_state_dict = model.state_dict()
 
@@ -1446,8 +1469,18 @@ def move_missing_keys_back_to_cpu(
 
 
 def initialize_missing_keys(
-    model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
-):
+    model: "PreTrainedModel",
+    loaded_keys: List[str],
+    ignore_mismatched_sizes: bool,
+    has_prefix_module: bool,
+    expects_prefix_module: bool,
+    is_quantized: bool,
+) -> "PreTrainedModel":
+    """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state
+    dicts), according to `_initialize_weights`. Since the corresponding weights are missing from the state dict, they
+    will not be replaced and need to be initialized correctly (i.e. drawn from the weight initialization distribution).
+    Also take care of setting the `_is_hf_initialized` flag for keys that are not missing.
+    """
     prefix = model.base_model_prefix
     remove_prefix_from_model = not has_prefix_module and expects_prefix_module
 
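As a rough mental model for `find_missing_and_unexpected_keys`: stripped of prefix handling, buffer cleanup and quantizer hooks, the core is two set differences. A simplified sketch (`naive_missing_and_unexpected` is hypothetical, not the function above):

    from typing import List, Tuple

    def naive_missing_and_unexpected(
        expected_keys: List[str], loaded_keys: List[str]
    ) -> Tuple[List[str], List[str]]:
        """Set-difference core of the real helper, ignoring prefixes and buffers."""
        missing = sorted(set(expected_keys) - set(loaded_keys))
        unexpected = sorted(set(loaded_keys) - set(expected_keys))
        return missing, unexpected

    missing, unexpected = naive_missing_and_unexpected(
        ["embed.weight", "lm_head.weight"],
        ["embed.weight", "rotary_emb.inv_freq"],
    )
    print(missing)     # ['lm_head.weight']
    print(unexpected)  # ['rotary_emb.inv_freq']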
+ """ prefix = model.base_model_prefix remove_prefix_from_model = not has_prefix_module and expects_prefix_module @@ -1493,18 +1526,20 @@ def initialize_missing_keys( def find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - original_loaded_keys, - add_prefix_to_model, - remove_prefix_from_model, - ignore_mismatched_sizes, - prefix, -): + state_dict: Dict, + model_state_dict: Dict, + renamed_loaded_keys: List[str], + original_loaded_keys: List[str], + add_prefix_to_model: bool, + remove_prefix_from_model: bool, + ignore_mismatched_sizes: bool, + prefix: str, +) -> List: + """Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the + problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`.""" mismatched_keys = [] if ignore_mismatched_sizes: - for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys): + for checkpoint_key, model_key in zip(original_loaded_keys, renamed_loaded_keys): # If the checkpoint is sharded, we may not have the key here. if checkpoint_key not in state_dict: continue @@ -4459,10 +4494,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, sharded_metadata=sharded_metadata, - _fast_init=_fast_init, low_cpu_mem_usage=low_cpu_mem_usage, device_map=device_map, - offload_folder=offload_folder, + disk_offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, hf_quantizer=hf_quantizer, @@ -4637,22 +4671,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix @classmethod def _load_pretrained_model( cls, - model, - state_dict, - checkpoint_files, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, - sharded_metadata=None, - _fast_init=True, - low_cpu_mem_usage=False, - device_map=None, - offload_folder=None, - offload_state_dict=None, - dtype=None, - hf_quantizer=None, - keep_in_fp32_modules=None, - gguf_file=None, - weights_only=True, + model: "PreTrainedModel", + state_dict: Optional[Dict], + checkpoint_files: Optional[List[str]], + pretrained_model_name_or_path: Optional[str], + ignore_mismatched_sizes: bool = False, + sharded_metadata: Optional[Dict] = None, + low_cpu_mem_usage: bool = False, + device_map: Optional[Dict] = None, + disk_offload_folder: Optional[str] = None, + offload_state_dict: Optional[bool] = None, + dtype: Optional[torch.dtype] = None, + hf_quantizer: Optional[HfQuantizer] = None, + keep_in_fp32_modules: Optional[List[str]] = None, + gguf_file: Optional[str] = None, + weights_only: bool = True, ): # Get all the keys of the state dicts that we have to initialize the model if sharded_metadata is not None: @@ -4667,16 +4700,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None is_offloaded_safetensors = False - if is_from_file and device_map is not None and "disk" in device_map.values(): - is_offloaded_safetensors = checkpoint_files[0].endswith(".safetensors") - if offload_folder is None and not is_offloaded_safetensors: + if device_map is not None and "disk" in device_map.values(): + is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors") + if disk_offload_folder is None and not is_offloaded_safetensors: raise ValueError( "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" " for them. 
@@ -4667,16 +4700,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
 
         is_offloaded_safetensors = False
-        if is_from_file and device_map is not None and "disk" in device_map.values():
-            is_offloaded_safetensors = checkpoint_files[0].endswith(".safetensors")
-            if offload_folder is None and not is_offloaded_safetensors:
+        if device_map is not None and "disk" in device_map.values():
+            is_offloaded_safetensors = is_from_file and checkpoint_files[0].endswith(".safetensors")
+            if disk_offload_folder is None and not is_offloaded_safetensors:
                 raise ValueError(
                     "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                     " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                     " offers the weights in this format."
                 )
-            if offload_folder is not None:
-                os.makedirs(offload_folder, exist_ok=True)
+            if disk_offload_folder is not None:
+                os.makedirs(disk_offload_folder, exist_ok=True)
 
         if offload_state_dict is None:
             offload_state_dict = True
@@ -4693,9 +4726,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if hf_quantizer is not None:
             expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_state_dict_keys)
 
-        loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
+        renamed_loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_state_dict_keys]
 
-        has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) if len(prefix) > 0 else False
+        has_prefix_module = any(s.startswith(prefix) for s in renamed_loaded_keys) if len(prefix) > 0 else False
         expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) if len(prefix) > 0 else False
         remove_prefix_from_model = not has_prefix_module and expects_prefix_module
 
@@ -4713,7 +4746,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             cls,
             model,
             expected_keys,
-            loaded_keys,
+            renamed_loaded_keys,
             remove_prefix_from_model,
             add_prefix_to_model,
             hf_quantizer,
@@ -4728,10 +4761,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         )
 
         # correctly initialize the missing keys
-        if _fast_init:
-            model = initialize_missing_keys(
-                model, loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
-            )
+        model = initialize_missing_keys(
+            model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
+        )
 
         # Set some modules to fp32 if needed
         if keep_in_fp32_modules is not None:
@@ -4748,7 +4780,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
             model_to_load = getattr(model, cls.base_model_prefix)
             base_model_expected_keys = list(model_to_load.state_dict().keys())
-            if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
+            if any(
+                key in expected_keys_not_prefixed and key not in base_model_expected_keys
+                for key in renamed_loaded_keys
+            ):
                 raise ValueError(
                     "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                     "properly saved?"
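For orientation, the disk-offload validation above is what a user reaches through the public API: the (unrenamed) `offload_folder` kwarg of `from_pretrained` feeds `disk_offload_folder`, as the earlier `from_pretrained` hunk shows. A typical call (checkpoint name illustrative):

    from transformers import AutoModelForCausalLM

    # With device_map="auto", weights that fit neither on GPU nor in CPU RAM go to
    # "disk"; for non-safetensors checkpoints an offload folder is then mandatory.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        device_map="auto",
        offload_folder="./offload",
    )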
@@ -4762,8 +4797,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         else:
             folder = None
 
-        # In case we need to offload to disk, compute the index
-        offload_index = None
+        # This offload index is for params explicitly placed on "disk" in the device_map
+        disk_offload_index = None
         if is_offloaded_safetensors:
             param_device_map = expand_device_map(device_map, loaded_state_dict_keys, start_prefix)
             str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
@@ -4771,21 +4806,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 weight_map = {p: checkpoint_files[0] for p in loaded_state_dict_keys}
             else:
                 weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
-            offload_index = {
+            disk_offload_index = {
                 p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
                 for p, f in weight_map.items()
                 if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
             }
         elif device_map is not None and "disk" in device_map.values():
-            offload_index = {}
+            disk_offload_index = {}
 
-        state_dict_folder = None
-        state_dict_index = None
+        # This offload index is for params that are supposed to be on the "cpu", either with or without a device_map.
+        # It allows loading the parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size,
+        # i.e. 1x to load it, and 1x to copy it into the model
+        cpu_offload_folder = None
+        cpu_offload_index = None
         if offload_state_dict:
-            state_dict_folder = tempfile.mkdtemp()
-            state_dict_index = {}
+            cpu_offload_folder = tempfile.mkdtemp()
+            cpu_offload_index = {}
 
-        # Find checkpoint files containing only weights offloaded to disk if any
+        # Find checkpoint files containing only weights offloaded to disk, if any (makes the loading faster)
         disk_only_shard_files = []
         if is_sharded_offloaded_safetensors:
             disk_only_shard_files = get_disk_only_shard_files(
@@ -4830,7 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 mismatched_keys += find_mismatched_keys(
                     state_dict,
                     model_state_dict,
-                    loaded_keys,
+                    renamed_loaded_keys,
                     loaded_state_dict_keys,
                     add_prefix_to_model,
                     remove_prefix_from_model,
@@ -4846,16 +4884,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             else:
                 fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
-                new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
+                new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
                     model_to_load,
                     fixed_state_dict,
                     start_prefix,
                     expected_keys,
                     device_map=device_map,
-                    offload_folder=offload_folder,
-                    offload_index=offload_index,
-                    state_dict_folder=state_dict_folder,
-                    state_dict_index=state_dict_index,
+                    disk_offload_folder=disk_offload_folder,
+                    disk_offload_index=disk_offload_index,
+                    cpu_offload_folder=cpu_offload_folder,
+                    cpu_offload_index=cpu_offload_index,
                     dtype=dtype,
                     hf_quantizer=hf_quantizer,
                     is_safetensors=is_offloaded_safetensors,
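The `cpu_offload_folder` comments describe a spill-and-reload pattern in the spirit of accelerate's `offload_weight`/`load_offloaded_weights`. A self-contained sketch of the idea using numpy memmaps directly (not the accelerate helpers): write each tensor to its own file while consuming the state dict, then map them back one at a time, so peak memory stays near a single tensor instead of 2x the state dict.

    import os
    import tempfile

    import numpy as np
    import torch

    folder = tempfile.mkdtemp()
    state_dict = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

    # Spill: one .dat file per tensor; keep only metadata in memory.
    index = {}
    for name, tensor in state_dict.items():
        array = tensor.numpy()
        mm = np.memmap(os.path.join(folder, f"{name}.dat"), dtype=array.dtype, mode="w+", shape=array.shape)
        mm[:] = array
        mm.flush()
        index[name] = {"dtype": array.dtype.name, "shape": array.shape}
    del state_dict  # the full dict is gone; only the per-tensor files remain

    # Reload: memory-map one tensor at a time instead of reading everything at once.
    for name, meta in index.items():
        mm = np.memmap(os.path.join(folder, f"{name}.dat"), dtype=meta["dtype"], mode="r", shape=meta["shape"])
        tensor = torch.from_numpy(np.array(mm))  # np.array copies the data out of the map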
@@ -4878,25 +4916,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             del state_dict
             gc.collect()
 
-        if offload_index is not None and len(offload_index) > 0:
+        # Adjust offloaded weight names and save the index if needed
+        if disk_offload_index is not None and len(disk_offload_index) > 0:
             if model != model_to_load:
                 # We need to add the prefix of the base model
                 prefix = cls.base_model_prefix
                 if not is_offloaded_safetensors:
-                    for weight_name in offload_index:
+                    for weight_name in disk_offload_index:
                         shutil.move(
-                            os.path.join(offload_folder, f"{weight_name}.dat"),
-                            os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
+                            os.path.join(disk_offload_folder, f"{weight_name}.dat"),
+                            os.path.join(disk_offload_folder, f"{prefix}.{weight_name}.dat"),
                         )
-                offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
+                disk_offload_index = {f"{prefix}.{key}": value for key, value in disk_offload_index.items()}
             if not is_offloaded_safetensors:
-                save_offload_index(offload_index, offload_folder)
-                offload_index = None
+                save_offload_index(disk_offload_index, disk_offload_folder)
+                disk_offload_index = None
 
+        # Load the cpu params back one-by-one from their temporary files
         if offload_state_dict:
             # Load back temporarily offloaded state dict
-            load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
-            shutil.rmtree(state_dict_folder)
+            load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
+            shutil.rmtree(cpu_offload_folder)
 
         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
@@ -4947,7 +4987,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 " to use it for predictions and inference."
             )
 
-        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
+        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
 
     @classmethod
     def _load_from_tf(cls, model, config, checkpoint_files):
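A closing note on the `loaded_keys` → `renamed_loaded_keys` rename that runs through the patch: the keys produced by `_fix_state_dict_key_on_load` are checkpoint keys after legacy renames, which is what the new name distinguishes from `original_loaded_keys`. A simplified illustration of that mapping (`fix_key_on_load` is hypothetical; the real method lives on `PreTrainedModel`, handles more cases, and returns a tuple, hence the `[0]` above):

    from typing import List

    def fix_key_on_load(key: str) -> str:
        """Simplified legacy renaming: old TF-style parameter names map to torch names."""
        return key.replace(".gamma", ".weight").replace(".beta", ".bias")

    original_loaded_keys: List[str] = ["encoder.layer_norm.gamma", "encoder.layer_norm.beta"]
    renamed_loaded_keys = [fix_key_on_load(k) for k in original_loaded_keys]
    print(renamed_loaded_keys)  # ['encoder.layer_norm.weight', 'encoder.layer_norm.bias']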