mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Update modeling_utils.py
This commit is contained in:
parent
b149b1f6fe
commit
8aca12c774
1 changed files with 18 additions and 14 deletions
|
|
@ -1464,18 +1464,18 @@ def _move_missing_keys_back_to_cpu(
|
|||
key = f"{prefix}.{key}"
|
||||
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
|
||||
key = ".".join(key.split(".")[1:])
|
||||
|
||||
param = model_state_dict[key]
|
||||
|
||||
# upcast in fp32 if any
|
||||
target_dtype = dtype
|
||||
if (
|
||||
keep_in_fp32_modules is not None
|
||||
and dtype == torch.float16
|
||||
and any(module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
|
||||
):
|
||||
target_dtype = torch.float32
|
||||
|
||||
if param.device == torch.device("meta"):
|
||||
# upcast in fp32 if any
|
||||
target_dtype = dtype
|
||||
if (
|
||||
keep_in_fp32_modules is not None
|
||||
and dtype == torch.float16
|
||||
and any(module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
|
||||
):
|
||||
target_dtype = torch.float32
|
||||
|
||||
value = torch.empty(*param.size(), dtype=target_dtype)
|
||||
if (
|
||||
not is_quantized
|
||||
|
|
@ -4441,18 +4441,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
)
|
||||
|
||||
# Instantiate contexts under which to load the model
|
||||
init_contexts = [no_init_weights(_enable=_fast_init)]
|
||||
|
||||
init_contexts = []
|
||||
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
init_contexts = [
|
||||
init_contexts.extend([
|
||||
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
|
||||
set_zero3_state(),
|
||||
] + init_contexts
|
||||
no_init_weights(_enable=_fast_init)
|
||||
])
|
||||
# In this case, simply load on meta device
|
||||
elif low_cpu_mem_usage:
|
||||
init_contexts.append(init_empty_weights())
|
||||
# Oherwise, load without initializing weights (i.e., not calling functions to instantiate proper weight distribution)
|
||||
else:
|
||||
init_contexts.append(no_init_weights(_enable=_fast_init))
|
||||
|
||||
if is_deepspeed_zero3_enabled() and is_quantized:
|
||||
init_contexts.append(set_quantized_state())
|
||||
|
|
|
|||
Loading…
Reference in a new issue