mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Fix loading with only state dict and low_cpu_mem_usage = True (#35217)
* fix loading with only state dict and config
* style
* add tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
parent
0531d7513b
commit
1eee1cedfd
2 changed files with 26 additions and 3 deletions
|
|
@@ -4022,8 +4022,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
else:
|
||||
loaded_state_dict_keys = list(state_dict.keys())
|
||||
|
||||
if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
|
||||
if (
|
||||
gguf_path is None
|
||||
and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()))
|
||||
and pretrained_model_name_or_path is not None
|
||||
):
|
||||
# In case some weights need to be kept in float32 and accelerate is not installed,
|
||||
# we later on want to take the path where state_dict is not None, that is the one
|
||||
# that do not require accelerate.
|
||||
|
|
@@ -4679,7 +4682,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
)
|
||||
|
||||
# For GGUF models `state_dict` is never set to None as the state dict is always small
|
||||
if gguf_path:
|
||||
if gguf_path or low_cpu_mem_usage:
|
||||
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
|
||||
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
|
|
|
|||
|
|
@@ -1750,6 +1750,26 @@ class ModelUtilsTest(TestCasePlus):
|
|||
new_model.generate(random_ids, max_new_tokens=3)
|
||||
self.assertTrue(len(w) == 0)
|
||||
|
||||
def test_load_model_with_state_dict_only(self):
    """Loading with no checkpoint path — only a config and a state dict —
    must reconstruct a model equal to the original."""
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    weights = reference.state_dict()

    # pretrained_model_name_or_path=None exercises the path-free loading branch:
    # everything must come from the explicit config + state_dict arguments.
    rebuilt = BertModel.from_pretrained(
        pretrained_model_name_or_path=None, config=reference.config, state_dict=weights
    )
    self.assertTrue(check_models_equal(reference, rebuilt))
|
||||
|
||||
def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
    """Same as the state-dict-only load, but with low_cpu_mem_usage=True:
    the meta-device loading path must also work without a checkpoint path."""
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    weights = reference.state_dict()

    # low_cpu_mem_usage=True routes through _load_state_dict_into_meta_model;
    # the result must still match a normally loaded model.
    rebuilt = BertModel.from_pretrained(
        pretrained_model_name_or_path=None,
        config=reference.config,
        state_dict=weights,
        low_cpu_mem_usage=True,
    )
    self.assertTrue(check_models_equal(reference, rebuilt))
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
|
|
|||
Loading…
Reference in a new issue