remove old functions

Cyril Vallez 2025-02-08 16:44:50 +01:00
parent 98aa2bdad6
commit aca9b22c25
2 changed files with 16 additions and 76 deletions


@@ -725,32 +725,6 @@ def find_submodule_and_param_name(model, long_key, start_prefix):
     return submodule, split_key[0]
 
 
-def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
-    """
-    Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params.
-
-    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
-    `bert.pooler.dense.weight`
-    """
-    # dematerialize param storage for keys that are going to be replaced by state_dict, by
-    # putting those on the meta device
-    for k in loaded_state_dict_keys:
-        submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
-        if submodule is not None:
-            # selectively switch to the meta device only those params/buffers that will
-            # be next replaced from state_dict. This is a complex way to do `p.to_("meta")`
-            # since we have no in-place `to_` for tensors.
-            new_val = getattr(submodule, param_name)
-            if isinstance(new_val, torch.nn.Parameter):
-                # isinstance returns False for Params on meta device, so switch after the check
-                new_val = torch.nn.Parameter(new_val.to("meta"))
-            else:
-                new_val = new_val.to("meta")
-            setattr(submodule, param_name, new_val)
-
-
 def _load_state_dict_into_meta_model(
     model: "PreTrainedModel",
     state_dict: Dict,
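
The deleted helper predates `torch.device("meta")` being usable as a context manager, which is the pattern the rest of this commit standardizes on. As a minimal sketch (plain PyTorch, not code from this commit): building a module under the meta device allocates no real storage, which is exactly what `_move_model_to_meta` emulated by swapping each parameter and buffer onto `"meta"` after construction.

import torch
import torch.nn as nn

# Sketch only: parameters created under the meta device context are
# storage-less placeholders, so no memory is spent on weights that a
# checkpoint load will overwrite anyway.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

assert model.weight.device.type == "meta"  # nothing was allocated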
@@ -932,7 +906,7 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     return weights_name
 
 
-def _get_checkpoint_files(
+def get_resolved_checkpoint_files(
     pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
     subfolder: str,
     variant: Optional[str],
@@ -949,7 +923,10 @@ def _get_checkpoint_files(
     revision: str,
     commit_hash: Optional[str],
 ) -> Tuple[Optional[List[str]], Optional[Dict]]:
-    """Get the full checkpoint filenames where the weights reside, and optional metadata if the checkpoints are sharded."""
+    """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
+    checkpoints are sharded.
+    This function will download the data if necessary.
+    """
 
     is_sharded = False
     if pretrained_model_name_or_path is not None and gguf_file is None:
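
A hedged sketch of the return contract (the `weight_map` index layout is an assumption mirroring the `*.safetensors.index.json` format, not something this hunk shows): `checkpoint_files` holds resolved local paths once any download has happened, and `sharded_metadata` stays `None` for single-file checkpoints.

from typing import Dict, List, Optional, Tuple

# Sketch of consuming the pair returned above; `weight_map` maps parameter
# names to the shard file containing them (assumed index layout).
def unique_shards(resolved: Tuple[Optional[List[str]], Optional[Dict]]) -> List[str]:
    checkpoint_files, sharded_metadata = resolved
    if sharded_metadata is None:
        return checkpoint_files or []  # single file, or nothing was found
    return sorted(set(sharded_metadata["weight_map"].values()))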
@@ -4293,7 +4270,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not pass a `quantization_config` or that you did not load a quantized model from the Hub."
             )
 
-        checkpoint_files, sharded_metadata = _get_checkpoint_files(
+        checkpoint_files, sharded_metadata = get_resolved_checkpoint_files(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             subfolder=subfolder,
             variant=variant,
@@ -5035,47 +5012,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         return retrieved_modules
 
-    @staticmethod
-    def _load_pretrained_model_low_mem(
-        model,
-        loaded_state_dict_keys,
-        resolved_archive_file,
-        start_prefix="",
-        hf_quantizer=None,
-        pretrained_model_name_or_path=None,
-        weights_only=True,
-    ):
-        """
-        This is an experimental function that loads the model using ~1.x model size CPU memory
-
-        Before you call it, do:
-
-        1. save which state_dict keys are available
-        2. drop state_dict before model is created, since the latter takes 1x model size memory
-
-        Here then we continue:
-
-        3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
-        4. load state_dict 2nd time
-        5. replace the params/buffers from the state_dict
-
-        Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To
-        handle bitsandbytes, a non-empty `hf_quantizer` argument is needed.
-        """
-        _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
-        state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
-        expected_keys = loaded_state_dict_keys  # plug for missing expected_keys. TODO: replace with proper keys
-        fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict)
-        error_msgs = _load_state_dict_into_meta_model(
-            model,
-            fixed_state_dict,
-            start_prefix,
-            expected_keys=expected_keys,
-            hf_quantizer=hf_quantizer,
-        )
-        return error_msgs
-
     @classmethod
     def register_for_auto_class(cls, auto_class="AutoModel"):
         """


@@ -35,7 +35,7 @@ from torch import Tensor
 from vissl.models.model_helpers import get_trunk_forward_outputs
 
 from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
-from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
 from transformers.utils import logging
 
@@ -244,14 +244,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub
     our_model_func = RegNetModel
     if "in1k" in model_name:
         our_model_func = RegNetForImageClassification
-    our_model = our_model_func(our_config)
-    # place our model to the meta device (so remove all the weights)
-    our_model.to(torch.device("meta"))
+    with torch.device("meta"):
+        our_model = our_model_func(our_config)
     logger.info("Loading state_dict in our model.")
     # load state dict
     state_dict_keys = our_model.state_dict().keys()
-    PreTrainedModel._load_pretrained_model_low_mem(
-        our_model, state_dict_keys, [save_directory / f"{model_name}.pth"]
+    state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True)
+    fixed_state_dict = our_model._fix_state_dict_keys_on_load(state_dict)
+    _load_state_dict_into_meta_model(
+        our_model,
+        fixed_state_dict,
+        start_prefix="",
+        expected_keys=state_dict_keys,
     )
     logger.info("Finally, pushing!")
     # push it to hub
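
A hypothetical sanity check one might append after the call above (not part of this commit): verify that every parameter was materialized off the meta device.

# Assumed follow-up: any parameter still on "meta" was missing from the checkpoint.
leftover = [name for name, p in our_model.named_parameters() if p.device.type == "meta"]
assert not leftover, f"parameters never materialized: {leftover}"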