add key_mapping keyword

Cyril Vallez 2025-02-10 15:09:45 +01:00
parent 65d0cbc495
commit cb5de19d8f


@@ -3953,8 +3953,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary that maps device identifiers to their maximum memory. Will default to the maximum memory available for each
A dictionary that maps device identifiers to their maximum memory if using `device_map`. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
tp_plan (`str`, *optional*):
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
`tp_plan="auto"` to use a predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
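For context, a minimal sketch of how the tensor-parallel path described by `tp_plan` above is typically driven (the checkpoint name is illustrative, and per the docstring `tp_plan="auto"` is the only accepted value):

```python
# Minimal sketch, assuming an illustrative checkpoint name. Per the docstring,
# the script must be launched with torchrun, e.g.:
#   torchrun --nproc-per-node 4 script.py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
    tp_plan="auto",             # the only currently accepted value
)
```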
@@ -3978,12 +3982,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
use_safetensors (`bool`, *optional*, defaults to `None`):
Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
is not installed, it will be set to `False`.
weights_only (`bool`, *optional*, defaults to `True`):
Indicates whether the unpickler should be restricted to loading only tensors, primitive types,
dictionaries and any types added via `torch.serialization.add_safe_globals()`.
When set to `False`, we can load wrapper tensor subclass weights.
key_mapping (`Dict[str, str]`, *optional*):
A potential mapping of the weight names if using a model on the Hub which is compatible with a Transformers
architecture, but whose weight names were not converted accordingly.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
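A usage sketch of the new `key_mapping` keyword (the repo name and rename pattern are hypothetical; each entry maps a regex pattern to its replacement, applied to the checkpoint's weight names on load):

```python
# Hypothetical example: a Hub checkpoint matching a Transformers architecture
# but with a `transformer.` prefix where the architecture expects `model.`.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/compatible-checkpoint",          # hypothetical repo
    key_mapping={r"^transformer\.": "model."}, # regex pattern -> replacement
)
```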
@@ -4071,6 +4076,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
generation_config = kwargs.pop("generation_config", None)
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
key_mapping = kwargs.pop("key_mapping", None)
if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
raise ValueError(
@@ -4494,6 +4500,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_file=gguf_file,
device_mesh=device_mesh,
key_mapping=key_mapping,
weights_only=weights_only,
_fast_init=_fast_init,
)
@@ -4609,7 +4616,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
return key, False
@classmethod
def _fix_state_dict_keys_on_load(cls, state_dict: Dict[str, Any]):
def _fix_state_dict_keys_on_load(cls, state_dict: Dict[str, Any], key_mapping: Optional[Dict[str, str]] = None):
"""Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
Logs if any parameters have been renamed.
"""
@@ -4618,6 +4625,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
new_key, has_changed = cls._fix_state_dict_key_on_load(key)
# Optionally map the key according to `key_mapping`
if key_mapping is not None:
for pattern, replacement in key_mapping.items():
new_key, n_replace = re.subn(pattern, replacement, new_key)
# Early exit of the loop on the first matching pattern
if n_replace > 0:
has_changed = True
break
if has_changed:
state_dict[new_key] = state_dict.pop(key)
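In isolation, the renaming logic added above reduces to the following self-contained sketch (the legacy-name handling done by `_fix_state_dict_key_on_load` is omitted):

```python
import re
from typing import Any, Dict, Optional

def rename_keys(state_dict: Dict[str, Any], key_mapping: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    """Apply the first matching `key_mapping` pattern to each key, in place."""
    for key in list(state_dict.keys()):
        new_key = key
        if key_mapping is not None:
            for pattern, replacement in key_mapping.items():
                new_key, n_replace = re.subn(pattern, replacement, new_key)
                # Early exit of the loop on the first matching pattern
                if n_replace > 0:
                    break
        if new_key != key:
            state_dict[new_key] = state_dict.pop(key)
    return state_dict

# rename_keys({"transformer.h.0.ln.weight": 0}, {r"^transformer\.": "model."})
# -> {"model.h.0.ln.weight": 0}
```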
@@ -4670,6 +4686,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
keep_in_fp32_modules: Optional[List[str]] = None,
gguf_file: Optional[str] = None,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
key_mapping: Optional[Dict[str, str]] = None,
weights_only: bool = True,
_fast_init: bool = True,
):
@@ -4686,13 +4703,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
# Rename the keys (use dummy values in input dict as a small trick)
renamed_loaded_keys = list(
cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}, key_mapping).keys()
)
# The device_map and weight_map have to be changed accordingly to match the keys, which will be loaded and then renamed
if device_map is not None:
device_map = cls._fix_state_dict_keys_on_load(device_map)
device_map = cls._fix_state_dict_keys_on_load(device_map, key_mapping)
if sharded_metadata is not None and "weight_map" in sharded_metadata.keys():
sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"])
sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"], key_mapping)
# Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
prefix = model.base_model_prefix
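The "dummy values" trick above simply reuses the dict-based renamer on a plain list of key names; with the `rename_keys` helper sketched earlier, the equivalent is:

```python
# Toy keys; the renamer works on dicts, so a throwaway dict with empty string
# values lets it rename a plain list of key names.
loaded_state_dict_keys = ["transformer.h.0.ln.weight", "lm_head.weight"]
key_mapping = {r"^transformer\.": "model."}

renamed_loaded_keys = list(
    rename_keys({key: "" for key in loaded_state_dict_keys}, key_mapping).keys()
)
# -> ["model.h.0.ln.weight", "lm_head.weight"]
```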
@@ -4863,7 +4880,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
)
# Modify the keys if needed
state_dict = cls._fix_state_dict_keys_on_load(state_dict)
state_dict = cls._fix_state_dict_keys_on_load(state_dict, key_mapping)
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
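For reference, the mismatched-keys structure described in the comment can be reproduced with a small check like this (a sketch, not the library's exact implementation):

```python
import torch

def find_mismatched_keys(state_dict, model_state_dict):
    """Sketch: (key, checkpoint_shape, model_shape) tuples for disagreeing shapes."""
    return [
        (key, tuple(tensor.shape), tuple(model_state_dict[key].shape))
        for key, tensor in state_dict.items()
        if key in model_state_dict and tensor.shape != model_state_dict[key].shape
    ]

# find_mismatched_keys({"w": torch.zeros(2, 3)}, {"w": torch.zeros(3, 3)})
# -> [("w", (2, 3), (3, 3))]
```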