Fix: take into account meta device (#34134)

* Do not load for meta device

* Make some minor improvements

* Add test

* Update tests/utils/test_modeling_utils.py

Update test parameters

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Make the test simpler

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Tibor Reiss 2024-11-20 11:32:07 +01:00 committed by GitHub
parent 8cadf76e1c
commit f297af55df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 18 additions and 1 deletions

View file

@ -361,6 +361,9 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
Note: We fully disable this if we are using `deepspeed`
"""
if model_to_load.device.type == "meta":
return False
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False
@ -375,7 +378,7 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi
return False
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = list(model_to_load.state_dict().keys())[0]
first_key = next(iter(model_to_load.state_dict().keys()))
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

View file

@ -14,6 +14,7 @@
# limitations under the License.
import copy
import glob
import itertools
import json
import os
import os.path
@ -459,6 +460,19 @@ class ModelUtilsTest(TestCasePlus):
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
@require_torch
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):
with torch.device("meta"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
return all(value.device.type == "meta" for value in model.state_dict().values())
model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
dtypes = (None, "auto", torch.float16)
for model_id, dtype in itertools.product(model_ids, dtypes):
self.assertTrue(is_on_meta(model_id, dtype))
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. explicit from_pretrained's torch_dtype argument