This commit is contained in:
Cyril Vallez 2025-02-05 22:12:55 +01:00
parent ff1078387e
commit 27e1615466
No known key found for this signature in database

View file

@ -894,7 +894,6 @@ def _load_state_dict_into_meta_model(
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return
# In this case, let's parallelize the modules!
if tp_key_registry is not None:
plan = None
@ -906,7 +905,7 @@ def _load_state_dict_into_meta_model(
plan = tp_key_registry[module_prefix]["plan"]
prefix = module_prefix
break
if plan is not None:
del tp_key_registry[prefix]
parent_module = model
@ -4775,7 +4774,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# correctly initialize the missing keys if it was skipped before
if _fast_init:
model = _initialize_missing_keys(
model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
model,
renamed_loaded_keys,
ignore_mismatched_sizes,
has_prefix_module,
expects_prefix_module,
is_quantized,
)
# Set some modules to fp32 if needed
@ -4971,7 +4975,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
save_offload_index(disk_offload_index, disk_offload_folder)
disk_offload_index = None
# 1-by-1 param loading for the cpu params
# one-at-a-time param loading for the cpu offloaded params
if offload_state_dict:
# Load back temporarily offloaded state dict
load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)