Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
simplify renaming logic
commit 89c740d07a
parent 1c0c5cf7b4
1 changed file with 21 additions and 30 deletions
```diff
@@ -1439,41 +1439,35 @@ def _find_missing_and_unexpected_keys(
 
 
 def _find_mismatched_keys(
-    model: "PreTrainedModel",
+    model_to_load: "PreTrainedModel",
     state_dict: Dict,
     renamed_loaded_keys: List[str],
-    original_loaded_keys: List[str],
-    loading_base_model_from_task_state_dict: bool,
-    loading_task_model_from_base_state_dict: bool,
     ignore_mismatched_sizes: bool,
+    prefix_to_remove: str,
 ) -> List:
     """Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the
     problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
     mismatched_keys = []
     if ignore_mismatched_sizes:
-        prefix = model.base_model_prefix
-        model_state_dict = model.state_dict()
-        for checkpoint_key, model_key in zip(original_loaded_keys, renamed_loaded_keys):
+        model_state_dict = model_to_load.state_dict()
+        for key in renamed_loaded_keys:
             # If the checkpoint is sharded, we may not have the key here.
-            if checkpoint_key not in state_dict:
+            if key not in state_dict:
                 continue
-            model_key = _adjust_loaded_keys_prefix(
-                model_key, prefix, loading_base_model_from_task_state_dict, loading_task_model_from_base_state_dict
-            )
+            # Remove the prefix if needed
+            adjusted_key = key[len(prefix_to_remove) :]
 
-            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            if adjusted_key in model_state_dict and state_dict[key].shape != model_state_dict[adjusted_key].shape:
                 if (
-                    state_dict[checkpoint_key].shape[-1] == 1
-                    and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
+                    state_dict[key].shape[-1] == 1
+                    and state_dict[key].numel() * 2 == model_state_dict[adjusted_key].numel()
                 ):
                     # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
                     # Without matching with module type or parameter type it seems like a practical way to detect valid 4-bit weights.
                     pass
                 else:
-                    mismatched_keys.append(
-                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                    )
-                    del state_dict[checkpoint_key]
+                    mismatched_keys.append((key, state_dict[key].shape, model_state_dict[adjusted_key].shape))
+                    del state_dict[key]
     return mismatched_keys
 
 
```
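To see what the simplified helper now does end to end, here is a minimal, self-contained sketch of the same logic with toy tensors; `find_mismatched_keys_sketch` and everything in it are stand-ins for illustration, not the real transformers helpers. The packed-weight exemption fires exactly when a checkpoint tensor has a trailing dimension of 1 and half as many elements as the model parameter, which is how two 4-bit values sharing one 8-bit container show up.

```python
import torch

def find_mismatched_keys_sketch(model_state_dict, state_dict, renamed_loaded_keys,
                                ignore_mismatched_sizes, prefix_to_remove):
    # Toy re-implementation of the simplified logic above (a sketch, not the real helper).
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for key in renamed_loaded_keys:
            if key not in state_dict:  # sharded checkpoint: key may live in another shard
                continue
            # A single slice replaces the old per-key _adjust_loaded_keys_prefix() call.
            adjusted_key = key[len(prefix_to_remove):]
            if adjusted_key in model_state_dict and state_dict[key].shape != model_state_dict[adjusted_key].shape:
                if (state_dict[key].shape[-1] == 1
                        and state_dict[key].numel() * 2 == model_state_dict[adjusted_key].numel()):
                    continue  # packed 4-bit weight: two 4-bit values per uint8 container
                mismatched_keys.append((key, state_dict[key].shape, model_state_dict[adjusted_key].shape))
                del state_dict[key]
    return mismatched_keys

# Toy data: one genuine mismatch, one packed 4-bit weight that must be skipped.
model_sd = {"linear.weight": torch.zeros(4, 4), "packed.weight": torch.zeros(4, 4)}
ckpt_sd = {
    "base.linear.weight": torch.zeros(2, 4),  # real shape mismatch -> reported and dropped
    "base.packed.weight": torch.zeros(8, 1),  # 8 containers x 2 nibbles = 16 values -> skipped
}
print(find_mismatched_keys_sketch(model_sd, ckpt_sd, list(ckpt_sd.keys()), True, "base."))
# [('base.linear.weight', torch.Size([2, 4]), torch.Size([4, 4]))]
```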
```diff
@@ -4691,7 +4685,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         loaded_state_dict_keys = list(load_state_dict(checkpoint_files[0], weights_only=weights_only).keys())
 
         # Rename the keys (use dummy values in input dict as a small trick)
-        renamed_loaded_keys = list(cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys())
+        renamed_loaded_keys = list(
+            cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
+        )
 
         # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
         prefix = model.base_model_prefix
```
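The "small trick" in the comment is that `_fix_state_dict_keys_on_load` maps a state dict to a state dict, so feeding it a dict whose values are empty strings reuses its renaming logic on the key names alone, without materializing any tensors. A minimal sketch of the idea, with a hypothetical `fix_keys` renamer standing in for the real method (the gamma/beta rename is only an illustrative rule):

```python
def fix_keys(state_dict):
    # Stand-in for cls._fix_state_dict_keys_on_load: renames legacy parameter names.
    return {k.replace(".gamma", ".weight").replace(".beta", ".bias"): v
            for k, v in state_dict.items()}

loaded_keys = ["encoder.layer.0.ln.gamma", "encoder.layer.0.ln.beta"]
# Dummy "" values: only the renamed keys matter, no tensor data is touched.
renamed_keys = list(fix_keys({key: "" for key in loaded_keys}).keys())
print(renamed_keys)  # ['encoder.layer.0.ln.weight', 'encoder.layer.0.ln.bias']
```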
```diff
@@ -4839,9 +4835,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         # Compute expected model keys
         expected_keys = list(model_to_load.state_dict().keys())
         if hf_quantizer is not None:
-            modified_keys = loaded_state_dict_keys
+            modified_keys = renamed_loaded_keys
             if loading_base_model_from_task_state_dict:
-                modified_keys = [s[len(_prefix) :] for s in loaded_state_dict_keys if s.startswith(_prefix)]
+                modified_keys = [s[len(_prefix) :] for s in modified_keys if s.startswith(_prefix)]
             expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, modified_keys)
 
         error_msgs = []
```
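For the quantizer bookkeeping, the list comprehension keeps only checkpoint keys that live under the base-model prefix and strips that prefix, so `update_expected_keys` compares names in the base model's namespace; after this commit it also operates on the already-renamed keys rather than the raw `loaded_state_dict_keys`. A small illustration with a hypothetical prefix:

```python
_prefix = "bert."  # hypothetical base_model_prefix plus the trailing dot
modified_keys = ["bert.embeddings.word_embeddings.weight", "cls.predictions.bias"]
# Keep only keys under the base-model prefix, with the prefix stripped.
modified_keys = [s[len(_prefix):] for s in modified_keys if s.startswith(_prefix)]
print(modified_keys)  # ['embeddings.word_embeddings.weight']
```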
```diff
@@ -4858,20 +4854,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                     shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
                 )
 
+                # Modify the keys in-place if needed
+                state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+
                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
                 mismatched_keys += _find_mismatched_keys(
-                    model,
-                    state_dict,
-                    renamed_loaded_keys,
-                    loaded_state_dict_keys,
-                    loading_base_model_from_task_state_dict,
-                    loading_task_model_from_base_state_dict,
-                    ignore_mismatched_sizes,
+                    model_to_load, state_dict, renamed_loaded_keys, ignore_mismatched_sizes, start_prefix_to_remove
                 )
 
                 if low_cpu_mem_usage or gguf_file is not None:
-                    state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                     # Skip it with fsdp on ranks other than 0
                     if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
                         new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
```
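The net effect of this hunk (and the one below) is a single renaming point: each shard's keys are fixed immediately after the shard is read, so mismatch detection and both loading branches all see the same renamed names, and the old per-branch rename calls become redundant. A compressed sketch of the loop's new shape, with stand-in helpers rather than the real signatures:

```python
def load_shard(path):
    # Stand-in for load_state_dict(shard_file, ...): returns raw checkpoint names.
    return {"model.ln.gamma": [0.0]}

def fix_keys(sd):
    # Stand-in for cls._fix_state_dict_keys_on_load.
    return {k.replace(".gamma", ".weight"): v for k, v in sd.items()}

checkpoint_files = ["model-00001-of-00001"]
for shard_file in checkpoint_files:
    state_dict = load_shard(shard_file)
    state_dict = fix_keys(state_dict)  # rename once, right after reading the shard
    # Everything downstream (_find_mismatched_keys, meta-model loading, regular
    # loading) now receives renamed keys; no second fix_keys call is needed.
    print(list(state_dict.keys()))  # ['model.ln.weight']
```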
```diff
@@ -4898,7 +4890,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                     assign_to_params_buffers = check_support_param_buffer_assignment(
                         model_to_load, state_dict, start_prefix_to_remove
                     )
-                    state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                     error_msgs += _load_state_dict_into_model(
                         model_to_load, state_dict, start_prefix_to_remove, assign_to_params_buffers
                     )
```