diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e03f87f92..2b7454641 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2770,7 +2770,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix del state_dict[checkpoint_key] return mismatched_keys - folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) + if resolved_archive_file is not None: + folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) + else: + folder = None if device_map is not None and is_safetensors: param_device_map = expand_device_map(device_map, original_loaded_keys) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d7d82e694..f5d6357e9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2749,6 +2749,15 @@ class ModelUtilsTest(TestCasePlus): BertModel.from_pretrained(TINY_T5) self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out) + def test_model_from_pretrained_no_checkpoint(self): + config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + model = BertModel(config) + state_dict = model.state_dict() + + new_model = BertModel.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=state_dict) + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + @require_torch def test_model_from_config_torch_dtype(self): # test that the model can be instantiated with dtype of user's choice - as long as it's a