Update modeling_utils.py

This commit is contained in:
parent 1bdb7bba52
commit c3e818561e

1 changed file with 32 additions and 15 deletions
@@ -195,28 +195,32 @@ TORCH_INIT_FUNCTIONS = {
 @contextmanager
-def no_init_weights():
+def no_init_weights(_enable=True):
     """
     Context manager to globally disable weight initialization to speed up loading large models.
+
+    TODO(Patrick): Delete safety argument `_enable=True` at next major version. .
     """
     global _init_weights
     old_init_weights = _init_weights
-    _init_weights = False
 
-    def _skip_init(*args, **kwargs):
-        pass
+    if _enable:
+        _init_weights = False
 
-    # Save the original initialization functions
-    for name, init_func in TORCH_INIT_FUNCTIONS.items():
-        setattr(torch.nn.init, name, _skip_init)
+        def _skip_init(*args, **kwargs):
+            pass
+
+        # # Save the original initialization functions
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, _skip_init)
 
     try:
         yield
     finally:
         _init_weights = old_init_weights
-        # Restore the original initialization functions
-        for name, init_func in TORCH_INIT_FUNCTIONS.items():
-            setattr(torch.nn.init, name, init_func)
+        if _enable:
+            # # Restore the original initialization functions
+            for name, init_func in TORCH_INIT_FUNCTIONS.items():
+                setattr(torch.nn.init, name, init_func)
 
 
 @contextmanager
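For context, the restored `_enable` branch follows a simple pattern: while it is enabled, every function listed in TORCH_INIT_FUNCTIONS is swapped for a no-op on torch.nn.init, and the originals are put back in the finally block so later code initializes weights normally. A minimal standalone sketch of that pattern, assuming an illustrative SKIPPABLE_INITS table and a skip_torch_init name that are not part of the library:

from contextlib import contextmanager

import torch

# Illustrative subset; the real TORCH_INIT_FUNCTIONS table covers many more initializers.
SKIPPABLE_INITS = {
    "uniform_": torch.nn.init.uniform_,
    "normal_": torch.nn.init.normal_,
    "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
}


@contextmanager
def skip_torch_init(enable=True):
    """Temporarily turn the selected torch.nn.init functions into no-ops."""

    def _skip(*args, **kwargs):
        pass

    if enable:
        for name in SKIPPABLE_INITS:
            setattr(torch.nn.init, name, _skip)
    try:
        yield
    finally:
        if enable:
            # Put the originals back so code outside the context initializes normally.
            for name, original in SKIPPABLE_INITS.items():
                setattr(torch.nn.init, name, original)


with skip_torch_init(enable=True):
    layer = torch.nn.Linear(8, 8)  # init calls are skipped; weights keep their torch.empty contents

Skipping the constructors is what makes loading fast: for weights that exist in the checkpoint the allocated memory is overwritten anyway, so their initial values never matter.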
@@ -3904,7 +3908,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
 
                 </Tip>
+            _fast_init(`bool`, *optional*, defaults to `True`):
+                Whether or not to disable fast initialization.
+
+                <Tip warning={true}>
+
+                One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
+                4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
+                [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
+
+                </Tip>
             attn_implementation (`str`, *optional*):
                 The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
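As a usage note for the docstring entry above, `_fast_init` is a keyword argument of `from_pretrained`; a short hedged example, where the checkpoint name is only a placeholder:

from transformers import AutoModel

# Default: fast initialization is enabled, random init is skipped, and only the
# parameters missing from the checkpoint are initialized afterwards.
model = AutoModel.from_pretrained("bert-base-uncased")

# Disable it only for backwards compatibility with seeded initialization in
# transformers < 4.6.0; every parameter then goes through its normal init
# before being overwritten by the checkpoint, which is slower.
compat_model = AutoModel.from_pretrained("bert-base-uncased", _fast_init=False)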
@@ -4052,6 +4065,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         _ = kwargs.pop("mirror", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        _fast_init = kwargs.pop("_fast_init", True)
         torch_dtype = kwargs.pop("torch_dtype", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
         device_map = kwargs.pop("device_map", None)
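The new `_fast_init = kwargs.pop("_fast_init", True)` line follows the same pattern as the surrounding flags: the private keyword is consumed with a default so it does not leak into the kwargs that are forwarded downstream. A small illustration of the pattern; the `load` helper is hypothetical, not library code:

def load(**kwargs):
    # Consume the private flag; callers that do not pass it get the default.
    fast_init = kwargs.pop("_fast_init", True)
    # Everything left in kwargs flows on to downstream consumers untouched.
    return fast_init, kwargs


assert load() == (True, {})
assert load(_fast_init=False, output_attentions=True) == (False, {"output_attentions": True})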
@@ -4397,7 +4411,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
             )
 
         # Instantiate contexts under which to load the model
-        init_contexts = [no_init_weights()]
+        init_contexts = [no_init_weights(_enable=_fast_init)]
 
         if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
             import deepspeed
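Because `init_contexts` is a list, whatever contexts end up in it are entered together before the model object is built. A generic sketch of entering a list of context managers with the standard library, not a claim about the exact helper transformers uses:

from contextlib import ExitStack, contextmanager


@contextmanager
def _noop():
    # Stand-in for no_init_weights(_enable=_fast_init) or the deepspeed zero-3 init context.
    yield


init_contexts = [_noop()]

with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    model = object()  # placeholder for the real model construction, done with all contexts active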
@@ -4478,6 +4492,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_file=gguf_file,
                 weights_only=weights_only,
+                _fast_init=_fast_init,
             )
 
             # make sure token embedding weights are still tied if needed
@@ -4661,6 +4676,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         keep_in_fp32_modules: Optional[List[str]] = None,
         gguf_file: Optional[str] = None,
         weights_only: bool = True,
+        _fast_init: bool = True,
     ):
         # Get all the keys of the state dicts that we have to initialize the model
         if sharded_metadata is not None:
@@ -4721,10 +4737,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 if param.device == torch.device("meta"):
                     set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
 
-        # correctly initialize the missing keys
-        model = _initialize_missing_keys(
-            model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
-        )
+        # correctly initialize the missing keys if it was skipped before
+        if _fast_init:
+            model = _initialize_missing_keys(
+                model, renamed_loaded_keys, ignore_mismatched_sizes, has_prefix_module, expects_prefix_module, is_quantized
+            )
 
         # Set some modules to fp32 if needed
         if keep_in_fp32_modules is not None:
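For context on why the call is now gated on `_fast_init`: with fast init enabled the constructors were skipped globally, so any parameter with no counterpart in the checkpoint still holds uninitialized memory and must be initialized explicitly here; with `_fast_init=False` the normal constructors already ran, and the pass can be skipped. A rough sketch of the idea, not the library's `_initialize_missing_keys` implementation; the function name and the fallback scheme are illustrative:

import torch


def init_missing_parameters(model: torch.nn.Module, loaded_keys: set) -> None:
    """Initialize parameters that the checkpoint did not provide."""
    missing = [name for name, _ in model.named_parameters() if name not in loaded_keys]
    with torch.no_grad():
        for name in missing:
            param = model.get_parameter(name)
            # Stand-in scheme; the real code dispatches to each module's _init_weights.
            if param.dim() > 1:
                torch.nn.init.xavier_uniform_(param)
            else:
                param.zero_()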