Update modeling_utils.py

This commit is contained in:
Cyril Vallez 2025-02-06 11:19:55 +01:00
parent b149b1f6fe
commit 8aca12c774
No known key found for this signature in database

View file

@ -1464,18 +1464,18 @@ def _move_missing_keys_back_to_cpu(
key = f"{prefix}.{key}"
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
# upcast in fp32 if any
target_dtype = dtype
if (
keep_in_fp32_modules is not None
and dtype == torch.float16
and any(module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
):
target_dtype = torch.float32
if param.device == torch.device("meta"):
# upcast in fp32 if any
target_dtype = dtype
if (
keep_in_fp32_modules is not None
and dtype == torch.float16
and any(module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
):
target_dtype = torch.float32
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
@ -4441,18 +4441,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# Instantiate contexts under which to load the model
init_contexts = [no_init_weights(_enable=_fast_init)]
init_contexts = []
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [
init_contexts.extend([
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
] + init_contexts
no_init_weights(_enable=_fast_init)
])
# In this case, simply load on meta device
elif low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
# Oherwise, load without initializing weights (i.e., not calling functions to instantiate proper weight distribution)
else:
init_contexts.append(no_init_weights(_enable=_fast_init))
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())