This commit is contained in:
Cyril Vallez 2025-02-07 11:07:53 +01:00
parent 8025c92c7b
commit f88bb46428
No known key found for this signature in database

View file

@ -1416,6 +1416,7 @@ def _find_missing_and_unexpected_keys(
loading_base_model_from_task_state_dict: bool, loading_base_model_from_task_state_dict: bool,
loading_task_model_from_base_state_dict: bool, loading_task_model_from_base_state_dict: bool,
hf_quantizer: Optional[HfQuantizer], hf_quantizer: Optional[HfQuantizer],
low_cpu_mem_usage: bool,
) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str]]:
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
(keys found in the loaded state dict keys, but that are NOT part of the model parameters) (keys found in the loaded state dict keys, but that are NOT part of the model parameters)
@ -1442,7 +1443,7 @@ def _find_missing_and_unexpected_keys(
unexpected_keys.add(key) unexpected_keys.add(key)
# Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
# may be in the loaded keys # may be in the loaded keys. Note that removing all buffers does the job: persistent buffers are part of the expected keys anyway, so they can never appear among the unexpected keys
model_buffers = {n for n, _ in model.named_buffers()} model_buffers = {n for n, _ in model.named_buffers()}
unexpected_keys = sorted(unexpected_keys - model_buffers) unexpected_keys = sorted(unexpected_keys - model_buffers)
@ -1453,7 +1454,17 @@ def _find_missing_and_unexpected_keys(
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k] unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
model.tie_weights() model.tie_weights()
tied_params = find_tied_parameters(model) if not low_cpu_mem_usage and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
else:
# The id function doesn't work for meta tensors, so we need to fall back to find_tied_parameters here
tied_params = find_tied_parameters(model)
for group in tied_params: for group in tied_params:
missing_in_group = [k for k in missing_keys if k in group] missing_in_group = [k for k in missing_keys if k in group]
@ -4794,6 +4805,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
loading_base_model_from_task_state_dict, loading_base_model_from_task_state_dict,
loading_task_model_from_base_state_dict, loading_task_model_from_base_state_dict,
hf_quantizer, hf_quantizer,
low_cpu_mem_usage,
) )
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as