simplify renaming logic

Cyril Vallez 2025-02-10 11:52:55 +01:00
parent 1c0c5cf7b4
commit 89c740d07a

@@ -1439,41 +1439,35 @@ def _find_missing_and_unexpected_keys(
 def _find_mismatched_keys(
-    model: "PreTrainedModel",
+    model_to_load: "PreTrainedModel",
     state_dict: Dict,
     renamed_loaded_keys: List[str],
-    original_loaded_keys: List[str],
-    loading_base_model_from_task_state_dict: bool,
-    loading_task_model_from_base_state_dict: bool,
     ignore_mismatched_sizes: bool,
+    prefix_to_remove: str,
 ) -> List:
     """Find shape mismatches between the model parameters and the loaded state dict, and optionally remove the
     problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`."""
     mismatched_keys = []
     if ignore_mismatched_sizes:
-        prefix = model.base_model_prefix
-        model_state_dict = model.state_dict()
-        for checkpoint_key, model_key in zip(original_loaded_keys, renamed_loaded_keys):
+        model_state_dict = model_to_load.state_dict()
+        for key in renamed_loaded_keys:
             # If the checkpoint is sharded, we may not have the key here.
-            if checkpoint_key not in state_dict:
+            if key not in state_dict:
                 continue
-            model_key = _adjust_loaded_keys_prefix(
-                model_key, prefix, loading_base_model_from_task_state_dict, loading_task_model_from_base_state_dict
-            )
-            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            # Remove the prefix if needed
+            adjusted_key = key[len(prefix_to_remove) :]
+            if adjusted_key in model_state_dict and state_dict[key].shape != model_state_dict[adjusted_key].shape:
                 if (
-                    state_dict[checkpoint_key].shape[-1] == 1
-                    and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
+                    state_dict[key].shape[-1] == 1
+                    and state_dict[key].numel() * 2 == model_state_dict[adjusted_key].numel()
                 ):
                     # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
                     # Without matching on module type or parameter type, it seems like a practical way to detect valid 4-bit weights.
                     pass
                 else:
-                    mismatched_keys.append(
-                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                    )
-                    del state_dict[checkpoint_key]
+                    mismatched_keys.append((key, state_dict[key].shape, model_state_dict[adjusted_key].shape))
+                    del state_dict[key]
     return mismatched_keys
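
For context on the 4-bit branch above: quantization backends such as bitsandbytes pack two 4-bit values into each uint8 element, so a packed checkpoint tensor holds half as many elements as the dequantized model parameter and keeps a trailing dimension of 1. A minimal sketch of that shape check, using made-up tensors rather than a real checkpoint:

import torch

# Hypothetical model parameter and its packed 4-bit counterpart (illustrative shapes only).
model_param = torch.zeros(4, 4, dtype=torch.float16)
packed = torch.zeros(model_param.numel() // 2, 1, dtype=torch.uint8)  # two 4-bit values per byte

# The same heuristic as in the diff: a trailing dim of 1 plus half the element
# count flags a valid packed 4-bit weight instead of a genuine shape mismatch.
is_packed_4bit = packed.shape[-1] == 1 and packed.numel() * 2 == model_param.numel()
assert is_packed_4bit
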
@@ -4691,7 +4685,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         loaded_state_dict_keys = list(load_state_dict(checkpoint_files[0], weights_only=weights_only).keys())
         # Rename the keys (use dummy values in input dict as a small trick)
-        renamed_loaded_keys = list(cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys())
+        renamed_loaded_keys = list(
+            cls._fix_state_dict_keys_on_load({key: "" for key in loaded_state_dict_keys}).keys()
+        )
 
         # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
         prefix = model.base_model_prefix
 
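
The "small trick" in this hunk works because the rename hook takes a whole state dict, so feeding it empty strings as stand-in values yields the renamed keys without materializing any tensors. A minimal sketch with a hypothetical rename rule (the real logic lives in `_fix_state_dict_keys_on_load`):

# Hypothetical rename rule standing in for the real hook.
def fix_keys(state_dict):
    return {k.replace("gamma", "weight").replace("beta", "bias"): v for k, v in state_dict.items()}

loaded_state_dict_keys = ["encoder.layer.0.gamma", "encoder.layer.0.beta"]
# Dummy values: only the renamed keys matter, no tensors are loaded.
renamed_loaded_keys = list(fix_keys({key: "" for key in loaded_state_dict_keys}).keys())
# -> ["encoder.layer.0.weight", "encoder.layer.0.bias"]
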
@@ -4839,9 +4835,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Compute expected model keys
         expected_keys = list(model_to_load.state_dict().keys())
         if hf_quantizer is not None:
-            modified_keys = loaded_state_dict_keys
+            modified_keys = renamed_loaded_keys
             if loading_base_model_from_task_state_dict:
-                modified_keys = [s[len(_prefix) :] for s in loaded_state_dict_keys if s.startswith(_prefix)]
+                modified_keys = [s[len(_prefix) :] for s in modified_keys if s.startswith(_prefix)]
             expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, modified_keys)
 
         error_msgs = []
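
The fix in this hunk makes the prefix-stripping comprehension operate on the renamed keys rather than the raw checkpoint keys. A small illustration of what the comprehension does when loading a base model from a task-model checkpoint, with made-up key names:

# Illustrative keys as they might appear in a task-model checkpoint.
renamed_loaded_keys = ["bert.embeddings.word_embeddings.weight", "cls.predictions.bias"]
_prefix = "bert."

# Keep only base-model weights and strip the prefix; head weights ("cls.*") drop out.
modified_keys = [s[len(_prefix) :] for s in renamed_loaded_keys if s.startswith(_prefix)]
# -> ["embeddings.word_embeddings.weight"]
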
@@ -4858,20 +4854,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
                 )
 
+                # Modify the keys in-place if needed
+                state_dict = cls._fix_state_dict_keys_on_load(state_dict)
+
                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
                 mismatched_keys += _find_mismatched_keys(
-                    model,
-                    state_dict,
-                    renamed_loaded_keys,
-                    loaded_state_dict_keys,
-                    loading_base_model_from_task_state_dict,
-                    loading_task_model_from_base_state_dict,
-                    ignore_mismatched_sizes,
+                    model_to_load, state_dict, renamed_loaded_keys, ignore_mismatched_sizes, start_prefix_to_remove
                 )
 
                 if low_cpu_mem_usage or gguf_file is not None:
-                    state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                     # Skip it with fsdp on ranks other than 0
                     if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
                         new_error_msgs, disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
@@ -4898,7 +4890,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     assign_to_params_buffers = check_support_param_buffer_assignment(
                         model_to_load, state_dict, start_prefix_to_remove
                     )
-                state_dict = cls._fix_state_dict_keys_on_load(state_dict)
                 error_msgs += _load_state_dict_into_model(
                     model_to_load, state_dict, start_prefix_to_remove, assign_to_params_buffers
                 )
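
Taken together, the last two hunks move the key renaming to a single point right after each shard is read, instead of repeating it in the low-CPU-memory and regular loading branches. A simplified sketch of the resulting per-shard order (helper names are stand-ins for the real functions in modeling_utils.py):

def load_all_shards(shard_paths, load_shard, rename_keys, find_mismatched, load_into_model):
    """Per-shard loading order after this change."""
    mismatched = []
    for shard_file in shard_paths:
        state_dict = load_shard(shard_file)
        state_dict = rename_keys(state_dict)       # renamed once, right after loading
        mismatched += find_mismatched(state_dict)  # mismatch check sees renamed keys
        load_into_model(state_dict)                # both loading paths see the same keys
    return mismatched
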