Fix loading with only state dict and low_cpu_mem_usage = True (#35217)

* fix loading with only state dict and config

* style

* add tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Marc Sun 2024-12-18 09:54:32 +01:00 committed by GitHub
parent 0531d7513b
commit 1eee1cedfd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 3 deletions

View file

@ -4022,8 +4022,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = list(state_dict.keys())
if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
if (
gguf_path is None
and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()))
and pretrained_model_name_or_path is not None
):
# In case some weights need to be kept in float32 and accelerate is not installed,
# we later on want to take the path where state_dict is not None, that is the one
# that do not require accelerate.
@ -4679,7 +4682,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# For GGUF models `state_dict` is never set to None as the state dict is always small
if gguf_path:
if gguf_path or low_cpu_mem_usage:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,

View file

@ -1750,6 +1750,26 @@ class ModelUtilsTest(TestCasePlus):
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)
def test_load_model_with_state_dict_only(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
state_dict = model.state_dict()
config = model.config
model_loaded = BertModel.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)
self.assertTrue(check_models_equal(model, model_loaded))
def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
state_dict = model.state_dict()
config = model.config
model_loaded = BertModel.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict, low_cpu_mem_usage=True
)
self.assertTrue(check_models_equal(model, model_loaded))
@slow
@require_torch