From 8aca12c77471685be31e34dfcb2f28a66f0c0163 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 6 Feb 2025 11:19:55 +0100
Subject: [PATCH] Update modeling_utils.py

---
 src/transformers/modeling_utils.py | 32 +++++++++++++++++-------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8e41d6e3b..e100aca98 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1464,18 +1464,18 @@ def _move_missing_keys_back_to_cpu(
             key = f"{prefix}.{key}"
         elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
             key = ".".join(key.split(".")[1:])
+
         param = model_state_dict[key]
-
-        # upcast in fp32 if any
-        target_dtype = dtype
-        if (
-            keep_in_fp32_modules is not None
-            and dtype == torch.float16
-            and any(module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
-        ):
-            target_dtype = torch.float32
-
         if param.device == torch.device("meta"):
+            # upcast in fp32 if any
+            target_dtype = dtype
+            if (
+                keep_in_fp32_modules is not None
+                and dtype == torch.float16
+                and any(module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
+            ):
+                target_dtype = torch.float32
+
             value = torch.empty(*param.size(), dtype=target_dtype)
             if (
                 not is_quantized
@@ -4441,18 +4441,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         )

         # Instantiate contexts under which to load the model
-        init_contexts = [no_init_weights(_enable=_fast_init)]
-
+        init_contexts = []
         if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
             import deepspeed

             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-            init_contexts = [
+            init_contexts.extend([
                 deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
                 set_zero3_state(),
-            ] + init_contexts
+                no_init_weights(_enable=_fast_init),
+            ])
+        # In this case, simply load on meta device
         elif low_cpu_mem_usage:
             init_contexts.append(init_empty_weights())
+        # Otherwise, load without initializing the weights (i.e., without calling the functions that set up the proper weight distributions)
+        else:
+            init_contexts.append(no_init_weights(_enable=_fast_init))

         if is_deepspeed_zero3_enabled() and is_quantized:
             init_contexts.append(set_quantized_state())
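
---

Note for reviewers: the net effect of the second hunk is that the three
initialization paths (DeepSpeed ZeRO-3, meta-device loading, plain
no-init loading) become mutually exclusive branches that each populate
init_contexts, instead of no_init_weights being stacked unconditionally.
Below is a minimal sketch of that control flow, using contextlib.ExitStack
as a stand-in for the library's own context stacking; the helper bodies
are simplified placeholders, not the actual transformers implementations.

    from contextlib import ExitStack, contextmanager

    @contextmanager
    def no_init_weights():
        # Placeholder for transformers' no_init_weights(_enable=_fast_init),
        # which skips the usual weight-initialization functions.
        yield

    @contextmanager
    def init_empty_weights():
        # Placeholder for accelerate's init_empty_weights(), which creates
        # parameters on the meta device.
        yield

    def build_init_contexts(deepspeed_zero3: bool, low_cpu_mem_usage: bool):
        # Mirrors the patched logic: exactly one branch runs, so the
        # no-init context is never stacked on top of the meta-device one.
        init_contexts = []
        if deepspeed_zero3:
            # The real code also appends deepspeed.zero.Init() and
            # set_zero3_state() here.
            init_contexts.append(no_init_weights())
        elif low_cpu_mem_usage:
            init_contexts.append(init_empty_weights())
        else:
            init_contexts.append(no_init_weights())
        return init_contexts

    # Enter all collected context managers at once before instantiating
    # the model, analogous to from_pretrained's use of the collected list.
    with ExitStack() as stack:
        for ctx in build_init_contexts(deepspeed_zero3=False, low_cpu_mem_usage=True):
            stack.enter_context(ctx)
        # model = cls(config, ...)  # instantiation happens under these contexts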