diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4de417eed..81902d02b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1416,6 +1416,7 @@ def _find_missing_and_unexpected_keys( loading_base_model_from_task_state_dict: bool, loading_task_model_from_base_state_dict: bool, hf_quantizer: Optional[HfQuantizer], + low_cpu_mem_usage: bool, ) -> Tuple[List[str], List[str]]: """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys (keys found in the loaded state dict keys, but that are NOT part of the model parameters) @@ -1442,7 +1443,7 @@ def _find_missing_and_unexpected_keys( unexpected_keys.add(key) # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but - # may be in the loaded keys + # may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway model_buffers = {n for n, _ in model.named_buffers()} unexpected_keys = sorted(unexpected_keys - model_buffers) @@ -1453,7 +1454,17 @@ def _find_missing_and_unexpected_keys( unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k] model.tie_weights() - tied_params = find_tied_parameters(model) + if not low_cpu_mem_usage and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + ptrs = collections.defaultdict(list) + for name, tensor in model.state_dict().items(): + id_tensor = id_tensor_storage(tensor) + ptrs[id_tensor].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + else: + # id function doesn't work for meta tensor so we need this function + tied_params = find_tied_parameters(model) for group in tied_params: missing_in_group = [k for k in missing_keys if k in group] @@ -4794,6 +4805,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix loading_base_model_from_task_state_dict, loading_task_model_from_base_state_dict, hf_quantizer, + low_cpu_mem_usage, ) # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as