mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
update
This commit is contained in:
parent
8025c92c7b
commit
f88bb46428
1 changed files with 14 additions and 2 deletions
|
|
@ -1416,6 +1416,7 @@ def _find_missing_and_unexpected_keys(
|
|||
loading_base_model_from_task_state_dict: bool,
|
||||
loading_task_model_from_base_state_dict: bool,
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
low_cpu_mem_usage: bool,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
||||
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
||||
|
|
@ -1442,7 +1443,7 @@ def _find_missing_and_unexpected_keys(
|
|||
unexpected_keys.add(key)
|
||||
|
||||
# Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
|
||||
# may be in the loaded keys
|
||||
# may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway
|
||||
model_buffers = {n for n, _ in model.named_buffers()}
|
||||
unexpected_keys = sorted(unexpected_keys - model_buffers)
|
||||
|
||||
|
|
@ -1453,7 +1454,17 @@ def _find_missing_and_unexpected_keys(
|
|||
unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
|
||||
|
||||
model.tie_weights()
|
||||
tied_params = find_tied_parameters(model)
|
||||
if not low_cpu_mem_usage and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model.state_dict().items():
|
||||
id_tensor = id_tensor_storage(tensor)
|
||||
ptrs[id_tensor].append(name)
|
||||
|
||||
# Keep only the groups of names that share the same underlying storage (i.e. tied tensors).
|
||||
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||
else:
|
||||
# The storage-id approach (`id_tensor_storage`) doesn't work for meta tensors, so use `find_tied_parameters` instead
|
||||
tied_params = find_tied_parameters(model)
|
||||
|
||||
for group in tied_params:
|
||||
missing_in_group = [k for k in missing_keys if k in group]
|
||||
|
|
@ -4794,6 +4805,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
loading_base_model_from_task_state_dict,
|
||||
loading_task_model_from_base_state_dict,
|
||||
hf_quantizer,
|
||||
low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
# Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as
|
||||
|
|
|
|||
Loading…
Reference in a new issue