Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
add key_mapping keyword
This commit is contained in:
parent 65d0cbc495
commit cb5de19d8f

1 changed file with 25 additions and 8 deletions
@@ -3953,8 +3953,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
-                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+                A dictionary device identifier to maximum memory if using `device_map`. Will default to the maximum memory available for each
                 GPU and the available CPU RAM if unset.
+            tp_plan (`str`, *optional*):
+                A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
+                `tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
+                `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*):
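For context on the two loading paths this docstring contrasts, here is a minimal sketch of both calls; the checkpoint name is illustrative, and the `tp_plan` variant assumes the script is launched with `torchrun` as described above:

    from transformers import AutoModelForCausalLM

    # Single-process loading: shard across devices, capping memory per device.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
        device_map="auto",
        max_memory={0: "20GiB", "cpu": "60GiB"},  # device identifier -> maximum memory
    )

    # Tensor-parallel loading: run with e.g. `torchrun --nproc-per-node 4 script.py`.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        tp_plan="auto",  # only "auto" is currently accepted
    )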
@@ -3978,12 +3982,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
                 is not installed, it will be set to `False`.
-
             weights_only (`bool`, *optional*, defaults to `True`):
                 Indicates whether unpickler should be restricted to loading only tensors, primitive types,
                 dictionaries and any types added via torch.serialization.add_safe_globals().
                 When set to False, we can load wrapper tensor subclass weights.
-
+            key_mapping (`Dict[str, str]`, *optional*):
+                A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers
+                architecture, but was not converted accordingly.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
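The `key_mapping` entry is terse about the expected shape of the dict; as the implementation added below shows, it maps regex patterns to replacements. A minimal sketch of the intended call, assuming a hypothetical Hub repo whose Llama-compatible weights use a legacy `transformer.h.*` naming instead of `model.layers.*`:

    from transformers import AutoModelForCausalLM

    # Hypothetical repo: architecture matches Llama, weight names do not.
    model = AutoModelForCausalLM.from_pretrained(
        "some-org/llama-compatible-checkpoint",  # hypothetical repo name
        key_mapping={r"^transformer\.h\.(\d+)\.": r"model.layers.\1."},
    )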
@@ -4071,6 +4076,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         generation_config = kwargs.pop("generation_config", None)
         gguf_file = kwargs.pop("gguf_file", None)
         tp_plan = kwargs.pop("tp_plan", None)
+        key_mapping = kwargs.pop("key_mapping", None)

         if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
             raise ValueError(
@@ -4494,6 +4500,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             keep_in_fp32_modules=keep_in_fp32_modules,
             gguf_file=gguf_file,
             device_mesh=device_mesh,
+            key_mapping=key_mapping,
             weights_only=weights_only,
             _fast_init=_fast_init,
         )
@@ -4609,7 +4616,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return key, False

     @classmethod
-    def _fix_state_dict_keys_on_load(cls, state_dict: Dict[str, Any]):
+    def _fix_state_dict_keys_on_load(cls, state_dict: Dict[str, Any], key_mapping: Optional[Dict[str, str]] = None):
         """Fixes state dict keys by replacing legacy parameter names with their modern equivalents.
         Logs if any parameters have been renamed.
         """
@@ -4618,6 +4625,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         state_dict_keys = list(state_dict.keys())
         for key in state_dict_keys:
             new_key, has_changed = cls._fix_state_dict_key_on_load(key)
+            # Optionally map the key according to `key_mapping`
+            if key_mapping is not None:
+                for pattern, replacement in key_mapping.items():
+                    new_key, n_replace = re.subn(pattern, replacement, new_key)
+                    # Early exit of the loop
+                    if n_replace > 0:
+                        has_changed = True
+                        break
+
             if has_changed:
                 state_dict[new_key] = state_dict.pop(key)

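The added branch applies each pattern with `re.subn` and stops at the first one that matches. A standalone, runnable sketch of just that branch (the legacy-name fixing done by `_fix_state_dict_key_on_load` is omitted):

    import re

    key_mapping = {r"^transformer\.h\.(\d+)\.": r"model.layers.\1."}
    state_dict = {"transformer.h.0.attn.weight": 0, "lm_head.weight": 1}

    for key in list(state_dict.keys()):
        new_key, has_changed = key, False
        for pattern, replacement in key_mapping.items():
            new_key, n_replace = re.subn(pattern, replacement, new_key)
            if n_replace > 0:  # first matching pattern wins, as in the loop above
                has_changed = True
                break
        if has_changed:
            state_dict[new_key] = state_dict.pop(key)

    print(sorted(state_dict))  # ['lm_head.weight', 'model.layers.0.attn.weight']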
@@ -4670,6 +4686,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         keep_in_fp32_modules: Optional[List[str]] = None,
         gguf_file: Optional[str] = None,
         device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        key_mapping: Optional[Dict[str, str]] = None,
         weights_only: bool = True,
         _fast_init: bool = True,
     ):
@@ -4686,13 +4703,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

         # Rename the keys (use dummy values in input dict as a small trick)
         renamed_loaded_keys = list(
-            cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
+            cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}, key_mapping).keys()
         )
         # The device_map and weight_map have to be changed accordingly to match the keys that will be loaded, then renamed
         if device_map is not None:
-            device_map = cls._fix_state_dict_keys_on_load(device_map)
+            device_map = cls._fix_state_dict_keys_on_load(device_map, key_mapping)
         if sharded_metadata is not None and "weight_map" in sharded_metadata.keys():
-            sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"])
+            sharded_metadata["weight_map"] = cls._fix_state_dict_keys_on_load(sharded_metadata["weight_map"], key_mapping)

         # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
         prefix = model.base_model_prefix
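The "small trick" mentioned in the comment is worth unpacking: `_fix_state_dict_keys_on_load` takes and returns a dict, so wrapping a bare list of key names in a throwaway dict lets the same code rename the names alone. A sketch with a stand-in fixer (the `gamma` -> `weight` rename mimics one of the legacy renames this method handles):

    # Stand-in for cls._fix_state_dict_keys_on_load: dict in, dict out.
    def fix_state_dict_keys(state_dict):
        return {key.replace("gamma", "weight"): value for key, value in state_dict.items()}

    loaded_state_dict_keys = ["encoder.layer.0.gamma", "encoder.layer.0.bias"]
    # Dummy "" values: only the (renamed) keys are kept.
    renamed_loaded_keys = list(fix_state_dict_keys({key: "" for key in loaded_state_dict_keys}).keys())
    print(renamed_loaded_keys)  # ['encoder.layer.0.weight', 'encoder.layer.0.bias']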
@@ -4863,7 +4880,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )

                 # Modify the keys if needed
-                state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+                state_dict = cls._fix_state_dict_keys_on_load(state_dict, key_mapping)

                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.