diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6ce64ee67..f642d011c 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3953,8 +3953,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
-                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+                A dictionary of device identifiers to maximum memory if using `device_map`. Will default to the maximum memory available for each
                 GPU and the available CPU RAM if unset.
+            tp_plan (`str`, *optional*):
+                A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
+                `tp_plan="auto"` to use a predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
+                `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*):
@@ -3978,12 +3982,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
                 is not installed, it will be set to `False`.
-
             weights_only (`bool`, *optional*, defaults to `True`):
                 Indicates whether unpickler should be restricted to loading only tensors, primitive types,
                 dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, we can load
                 wrapper tensor subclass weights.
-
+            key_mapping (`Dict[str, str]`, *optional*):
+                A potential mapping of the weight names if using a model on the Hub which is compatible with a Transformers
+                architecture, but was not converted accordingly.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -4071,6 +4076,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         generation_config = kwargs.pop("generation_config", None)
         gguf_file = kwargs.pop("gguf_file", None)
         tp_plan = kwargs.pop("tp_plan", None)
+        key_mapping = kwargs.pop("key_mapping", None)
 
         if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
             raise ValueError(
@@ -4494,6 +4500,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_file=gguf_file,
                 device_mesh=device_mesh,
+                key_mapping=key_mapping,
                 weights_only=weights_only,
                 _fast_init=_fast_init,
             )
@@ -4609,7 +4616,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return key, False
 
     @classmethod
-    def _fix_state_dict_keys_on_load(cls, state_dict: Dict[str, Any]):
+    def _fix_state_dict_keys_on_load(cls, state_dict: Dict[str, Any], key_mapping: Optional[Dict[str, str]] = None):
         """Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
         Logs if any parameters have been renamed.
""" @@ -4618,6 +4625,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix state_dict_keys = list(state_dict.keys()) for key in state_dict_keys: new_key, has_changed = cls._fix_state_dict_key_on_load(key) + # Optionally map the key according to `key_mapping` + if key_mapping is not None: + for pattern, replacement in key_mapping.items(): + new_key, n_replace = re.subn(pattern, replacement, new_key) + # Early exit of the loop + if n_replace > 0: + has_changed = True + break + if has_changed: state_dict[new_key] = state_dict.pop(key) @@ -4670,6 +4686,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix keep_in_fp32_modules: Optional[List[str]] = None, gguf_file: Optional[str] = None, device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + key_mapping: Optional[Dict[str, str]] = None, weights_only: bool = True, _fast_init: bool = True, ): @@ -4686,13 +4703,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Rename the keys (use dummy values in input dict as a small trick) renamed_loaded_keys = list( - cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys() + cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}, key_mapping).keys() ) # The device_map and weight_map have to be changed accordingly to match the keys that will be loaded, then renamed if device_map is not None: - device_map = cls._fix_state_dict_keys_on_load(device_map) + device_map = cls._fix_state_dict_keys_on_load(device_map, key_mapping) if sharded_metadata is not None and "weight_map" in sharded_metadata.keys(): - sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"]) + sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"], key_mapping) # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture prefix = model.base_model_prefix @@ -4863,7 +4880,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) # Modify the keys if needed - state_dict = cls._fix_state_dict_keys_on_load(state_dict) + state_dict = cls._fix_state_dict_keys_on_load(state_dict, key_mapping) # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model.