mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
remove old functions
This commit is contained in:
parent
98aa2bdad6
commit
aca9b22c25
2 changed files with 16 additions and 76 deletions
|
|
@ -725,32 +725,6 @@ def find_submodule_and_param_name(model, long_key, start_prefix):
|
|||
return submodule, split_key[0]
|
||||
|
||||
|
||||
def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
||||
"""
|
||||
Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params.
|
||||
|
||||
`start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
|
||||
`bert.pooler.dense.weight`
|
||||
|
||||
"""
|
||||
|
||||
# dematerialize param storage for keys that are going to be replaced by state_dict, by
|
||||
# putting those on the meta device
|
||||
for k in loaded_state_dict_keys:
|
||||
submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
|
||||
if submodule is not None:
|
||||
# selectively switch to the meta device only those params/buffers that will
|
||||
# be next replaced from state_dict. This a complex way to do p.to_("meta")
|
||||
# since we have no in-place to_ for tensors.
|
||||
new_val = getattr(submodule, param_name)
|
||||
if isinstance(new_val, torch.nn.Parameter):
|
||||
# isinstance returns False for Params on meta device, so switch after the check
|
||||
new_val = torch.nn.Parameter(new_val.to("meta"))
|
||||
else:
|
||||
new_val = new_val.to("meta")
|
||||
setattr(submodule, param_name, new_val)
|
||||
|
||||
|
||||
def _load_state_dict_into_meta_model(
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Dict,
|
||||
|
|
@ -932,7 +906,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
|||
return weights_name
|
||||
|
||||
|
||||
def _get_checkpoint_files(
|
||||
def get_resolved_checkpoint_files(
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
subfolder: str,
|
||||
variant: Optional[str],
|
||||
|
|
@ -949,7 +923,10 @@ def _get_checkpoint_files(
|
|||
revision: str,
|
||||
commit_hash: Optional[str],
|
||||
) -> Tuple[Optional[List[str]], Optional[Dict]]:
|
||||
"""Get the full checkpoint filenames where the weights reside, and optional metadata if the checkpoints are sharded."""
|
||||
"""Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
|
||||
checkpoints are sharded.
|
||||
This function will download the data if necessary.
|
||||
"""
|
||||
is_sharded = False
|
||||
|
||||
if pretrained_model_name_or_path is not None and gguf_file is None:
|
||||
|
|
@ -4293,7 +4270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
"You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not pass a `quantization_config` or that you did not load a quantized model from the Hub."
|
||||
)
|
||||
|
||||
checkpoint_files, sharded_metadata = _get_checkpoint_files(
|
||||
checkpoint_files, sharded_metadata = get_resolved_checkpoint_files(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
variant=variant,
|
||||
|
|
@ -5035,47 +5012,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
|
||||
return retrieved_modules
|
||||
|
||||
@staticmethod
|
||||
def _load_pretrained_model_low_mem(
|
||||
model,
|
||||
loaded_state_dict_keys,
|
||||
resolved_archive_file,
|
||||
start_prefix="",
|
||||
hf_quantizer=None,
|
||||
pretrained_model_name_or_path=None,
|
||||
weights_only=True,
|
||||
):
|
||||
"""
|
||||
This is an experimental function that loads the model using ~1.x model size CPU memory
|
||||
|
||||
Before you call it do:
|
||||
|
||||
1. save which state_dict keys are available
|
||||
2. drop state_dict before model is created, since the latter takes 1x model size memory
|
||||
|
||||
Here then we continue:
|
||||
|
||||
3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
|
||||
4. load state_dict 2nd time
|
||||
5. replace the params/buffers from the state_dict
|
||||
|
||||
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To
|
||||
handle bitsandbytes, needs non-empty hf_quantizer argument.
|
||||
"""
|
||||
|
||||
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
|
||||
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
|
||||
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
|
||||
fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs = _load_state_dict_into_meta_model(
|
||||
model,
|
||||
fixed_state_dict,
|
||||
start_prefix,
|
||||
expected_keys=expected_keys,
|
||||
hf_quantizer=hf_quantizer,
|
||||
)
|
||||
return error_msgs
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="AutoModel"):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from torch import Tensor
|
|||
from vissl.models.model_helpers import get_trunk_forward_outputs
|
||||
|
||||
from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
|
|
@ -244,14 +244,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
|
|||
our_model_func = RegNetModel
|
||||
if "in1k" in model_name:
|
||||
our_model_func = RegNetForImageClassification
|
||||
our_model = our_model_func(our_config)
|
||||
# place our model to the meta device (so remove all the weights)
|
||||
our_model.to(torch.device("meta"))
|
||||
with torch.device("meta"):
|
||||
our_model = our_model_func(our_config)
|
||||
logger.info("Loading state_dict in our model.")
|
||||
# load state dict
|
||||
state_dict_keys = our_model.state_dict().keys()
|
||||
PreTrainedModel._load_pretrained_model_low_mem(
|
||||
our_model, state_dict_keys, [save_directory / f"{model_name}.pth"]
|
||||
state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True)
|
||||
fixed_state_dict = our_model._fix_state_dict_keys_on_load(state_dict)
|
||||
_load_state_dict_into_meta_model(
|
||||
our_model,
|
||||
fixed_state_dict,
|
||||
start_prefix="",
|
||||
expected_keys=state_dict_keys,
|
||||
)
|
||||
logger.info("Finally, pushing!")
|
||||
# push it to hub
|
||||
|
|
|
|||
Loading…
Reference in a new issue