Fix: take into account meta device (#34134)
* Do not load for meta device
* Make some minor improvements
* Add test
* Update tests/utils/test_modeling_utils.py
  Update test parameters
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Make the test simpler

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 8cadf76e1c
commit f297af55df

2 changed files with 18 additions and 1 deletion
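To summarize the behavior this commit pins down: instantiating a model under torch.device("meta") should leave every parameter on the meta device, i.e. no real weights get materialized during from_pretrained. A minimal sketch of that expectation (it assumes this fix is applied and reuses one of the checkpoints from the new test below):

import torch
from transformers import AutoModelForCausalLM

# Instantiating under the "meta" device should not materialize any real
# weights; every parameter stays on meta (this is what the new test asserts).
with torch.device("meta"):
    model = AutoModelForCausalLM.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")

assert all(p.device.type == "meta" for p in model.state_dict().values())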
src/transformers/modeling_utils.py

@@ -361,6 +361,9 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix
 
     Note: We fully disable this if we are using `deepspeed`
     """
+    if model_to_load.device.type == "meta":
+        return False
+
     if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
         return False
 
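For context on the new guard: a tensor on the meta device carries only shape and dtype metadata, with no backing storage, so there is nothing meaningful to compare or assign during weight loading; returning False here keeps the loader off the assignment path. A small standalone illustration (not part of the commit):

import torch

# A meta tensor knows its shape and dtype but has no backing storage.
t = torch.empty(2, 3, device="meta")
print(t.device.type)  # "meta"
print(t.is_meta)      # True
# Anything that needs the actual values, e.g. t.tolist(), raises because
# there is no data behind the tensor.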
@@ -375,7 +378,7 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix
         return False
 
     # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-    first_key = list(model_to_load.state_dict().keys())[0]
+    first_key = next(iter(model_to_load.state_dict().keys()))
     if start_prefix + first_key in state_dict:
         return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
 
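The second hunk is a small cleanup: next(iter(...)) returns the first key without first building the full key list, and on an ordered mapping like a state dict both expressions yield the same key. A trivial illustration with a plain dict (key names are made up for the example):

d = {"model.embed_tokens.weight": 0, "lm_head.weight": 1}

# Both pick the first key; next(iter(...)) just avoids materializing the list.
assert list(d.keys())[0] == next(iter(d.keys())) == "model.embed_tokens.weight"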
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
import copy
|
||||
import glob
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
|
|
@ -459,6 +460,19 @@ class ModelUtilsTest(TestCasePlus):
|
|||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
|
||||
|
||||
@require_torch
|
||||
def test_model_from_pretrained_meta_device(self):
|
||||
def is_on_meta(model_id, dtype):
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
|
||||
return all(value.device.type == "meta" for value in model.state_dict().values())
|
||||
|
||||
model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
|
||||
dtypes = (None, "auto", torch.float16)
|
||||
|
||||
for model_id, dtype in itertools.product(model_ids, dtypes):
|
||||
self.assertTrue(is_on_meta(model_id, dtype))
|
||||
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. explicit from_pretrained's torch_dtype argument
|
||||
|
|
|
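The new test can be run on its own; a typical invocation using pytest's -k filter (the command itself is not part of the commit, just standard pytest usage):

python -m pytest tests/utils/test_modeling_utils.py -k test_model_from_pretrained_meta_device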