Update modeling_utils.py

Cyril Vallez 2025-02-05 15:35:10 +01:00
parent 1bdb7bba52
commit c3e818561e


@@ -195,28 +195,32 @@ TORCH_INIT_FUNCTIONS = {
 @contextmanager
-def no_init_weights():
+def no_init_weights(_enable=True):
     """
     Context manager to globally disable weight initialization to speed up loading large models.
+
+    TODO(Patrick): Delete safety argument `_enable=True` at next major version. .
     """
     global _init_weights
     old_init_weights = _init_weights

-    _init_weights = False
-
-    def _skip_init(*args, **kwargs):
-        pass
+    if _enable:
+        _init_weights = False

-    # Save the original initialization functions
-    for name, init_func in TORCH_INIT_FUNCTIONS.items():
-        setattr(torch.nn.init, name, _skip_init)
+        def _skip_init(*args, **kwargs):
+            pass
+
+        # # Save the original initialization functions
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, _skip_init)
     try:
         yield
     finally:
         _init_weights = old_init_weights
-        # Restore the original initialization functions
-        for name, init_func in TORCH_INIT_FUNCTIONS.items():
-            setattr(torch.nn.init, name, init_func)
+        if _enable:
+            # # Restore the original initialization functions
+            for name, init_func in TORCH_INIT_FUNCTIONS.items():
+                setattr(torch.nn.init, name, init_func)


 @contextmanager
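For readers skimming the diff: the hunk above makes init skipping conditional again. When `_enable` is truthy, every function listed in `TORCH_INIT_FUNCTIONS` is monkeypatched to a no-op so that building the model skeleton does not spend time on random initialization that loading the checkpoint would overwrite anyway. A minimal standalone sketch of that pattern follows; the trimmed `TORCH_INIT_FUNCTIONS` table and the `skip_torch_init` name are illustrative, not the library's actual objects.

from contextlib import contextmanager

import torch.nn as nn

# illustrative subset; the real table in modeling_utils.py covers every torch.nn.init function
TORCH_INIT_FUNCTIONS = {
    "uniform_": nn.init.uniform_,
    "normal_": nn.init.normal_,
    "kaiming_uniform_": nn.init.kaiming_uniform_,
}


@contextmanager
def skip_torch_init(enable=True):
    """Temporarily replace the listed torch.nn.init functions with no-ops."""

    def _skip_init(*args, **kwargs):
        pass

    if enable:
        for name in TORCH_INIT_FUNCTIONS:
            setattr(nn.init, name, _skip_init)
    try:
        yield
    finally:
        if enable:
            # put the real initializers back
            for name, init_func in TORCH_INIT_FUNCTIONS.items():
                setattr(nn.init, name, init_func)


with skip_torch_init():
    layer = nn.Linear(4096, 4096)  # weights stay as allocated memory; no random init is run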
@@ -3904,7 +3908,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

                 </Tip>

+            _fast_init(`bool`, *optional*, defaults to `True`):
+                Whether or not to disable fast initialization.
+
+                <Tip warning={true}>
+
+                One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
+                4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
+                [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
+
+                </Tip>
             attn_implementation (`str`, *optional*):
                 The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
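Taken together with the existing `attn_implementation` documentation, an illustrative `from_pretrained` call using both keyword arguments might look like the following; the checkpoint id is just a placeholder, and, as the docstring warns, `_fast_init=False` is only meant for pre-4.6.0 seeded-initialization compatibility.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",                          # any Hub id or local path
    attn_implementation="sdpa",      # "eager", "sdpa" or "flash_attention_2"
    _fast_init=True,                 # the default; disable only for legacy seeded init
)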
@@ -4052,6 +4065,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         _ = kwargs.pop("mirror", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        _fast_init = kwargs.pop("_fast_init", True)
         torch_dtype = kwargs.pop("torch_dtype", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
         device_map = kwargs.pop("device_map", None)
@@ -4397,7 +4411,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         )

         # Instantiate contexts under which to load the model
-        init_contexts = [no_init_weights()]
+        init_contexts = [no_init_weights(_enable=_fast_init)]

         if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
             import deepspeed
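The `init_contexts` list built here is entered as a group while the model skeleton is instantiated, so the weight-init skipping (and, when enabled, DeepSpeed ZeRO-3 partitioning) wraps the constructor call. As a rough sketch of the mechanics only, assuming nothing about the helper transformers actually uses to enter the list, `contextlib.ExitStack` can play the same role:

from contextlib import ExitStack, contextmanager


@contextmanager
def named_context(name):
    # stand-in for no_init_weights(), deepspeed.zero.Init(), etc.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


init_contexts = [named_context("no_init_weights"), named_context("deepspeed.zero.Init")]

with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    print("model = cls(config, ...) would run here, inside every context")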
@@ -4478,6 +4492,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_file=gguf_file,
                 weights_only=weights_only,
+                _fast_init=_fast_init,
             )

             # make sure token embedding weights are still tied if needed
@@ -4661,6 +4676,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         keep_in_fp32_modules: Optional[List[str]] = None,
         gguf_file: Optional[str] = None,
         weights_only: bool = True,
+        _fast_init: bool = True,
     ):
         # Get all the keys of the state dicts that we have to initialize the model
         if sharded_metadata is not None:
@@ -4721,10 +4737,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if param.device == torch.device("meta"):
                 set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))

-        # correctly initialize the missing keys
-        model = _initialize_missing_keys(
-            model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
-        )
+        # correctly initialize the missing keys if it was skipped before
+        if _fast_init:
+            model = _initialize_missing_keys(
+                model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
+            )

         # Set some modules to fp32 if needed
         if keep_in_fp32_modules is not None:
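Finally, the guarded `_initialize_missing_keys` call restores the initialization that `no_init_weights` skipped, but only for parameters the checkpoint does not provide. The helper below is a deliberately simplified, hypothetical sketch of that idea; the real function also handles prefixes, mismatched sizes, and quantization, per the arguments shown above.

import torch.nn as nn


def init_missing_keys_sketch(model: nn.Module, loaded_keys: set) -> nn.Module:
    """Re-run default init only for modules owning a parameter missing from the checkpoint."""
    missing = set(dict(model.named_parameters())) - set(loaded_keys)
    for name, module in model.named_modules():
        owns_missing = any(key.startswith(f"{name}.") for key in missing)
        if owns_missing and hasattr(module, "reset_parameters"):
            module.reset_parameters()
    return model


# e.g. a checkpoint that only contains the first layer of a two-layer stack
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
init_missing_keys_sketch(model, loaded_keys={"0.weight", "0.bias"})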