diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1c67ee1f8..114c90524 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -246,6 +246,25 @@ def set_zero3_state():
     _is_ds_init_called = False
 
 
+def restore_default_torch_dtype(func):
+    """
+    Decorator to restore the default torch dtype
+    at the end of the function. Serves
+    as a backup in case calling the function raises
+    an error after the function has changed the default dtype but before it could restore it.
+    """
+
+    @wraps(func)
+    def _wrapper(*args, **kwargs):
+        old_dtype = torch.get_default_dtype()
+        try:
+            return func(*args, **kwargs)
+        finally:
+            torch.set_default_dtype(old_dtype)
+
+    return _wrapper
+
+
 def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
@@ -1407,6 +1426,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         self.model_tags.append(tag)
 
     @classmethod
+    @restore_default_torch_dtype
     def _from_config(cls, config, **kwargs):
         """
         All context managers that the model should be initialized under go here.
@@ -3142,6 +3162,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             return super().float(*args)
 
     @classmethod
+    @restore_default_torch_dtype
     def from_pretrained(
         cls: Type[SpecificPreTrainedModelType],
         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 2179c4be5..cae38422f 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -39,6 +39,7 @@ from transformers import (
     AutoModelForSequenceClassification,
     DynamicCache,
     LlavaForConditionalGeneration,
+    MistralForCausalLM,
     OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
@@ -318,6 +319,14 @@ def check_models_equal(model1, model2):
 
 @require_torch
 class ModelUtilsTest(TestCasePlus):
+    def setUp(self):
+        self.old_dtype = torch.get_default_dtype()
+        super().setUp()
+
+    def tearDown(self):
+        torch.set_default_dtype(self.old_dtype)
+        super().tearDown()
+
     @slow
     def test_model_from_pretrained(self):
         model_name = "google-bert/bert-base-uncased"
@@ -1819,6 +1828,67 @@ class ModelUtilsTest(TestCasePlus):
         self.assertIsNone(model_outputs.past_key_values)
         self.assertTrue(model.training)
 
+    def test_restore_default_torch_dtype_from_pretrained(self):
+        """
+        Tests that the default torch dtype is restored
+        when an error happens during the loading of a model.
+        """
+        old_dtype = torch.get_default_dtype()
+        # set default type to float32
+        torch.set_default_dtype(torch.float32)
+
+        # Mock injection point which is right after the call to `_set_default_torch_dtype`
+        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
+
+        def debug(*args, **kwargs):
+            # call the method as usual, then raise a RuntimeError
+            original_set_default_torch_dtype(*args, **kwargs)
+            raise RuntimeError
+
+        with mock.patch(
+            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
+            side_effect=debug,
+        ):
+            with self.assertRaises(RuntimeError):
+                _ = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto", torch_dtype=torch.float16)
+        # default should still be float32
+        assert torch.get_default_dtype() == torch.float32
+        torch.set_default_dtype(old_dtype)
+
+    def test_restore_default_torch_dtype_from_config(self):
+        """
+        Tests that the default torch dtype is restored
+        when an error happens during the loading of a model.
+        """
+        old_dtype = torch.get_default_dtype()
+        # set default type to float32
+        torch.set_default_dtype(torch.float32)
+
+        config = AutoConfig.from_pretrained(
+            TINY_MISTRAL,
+        )
+
+        # Mock injection point which is right after the call to `_set_default_torch_dtype`
+        original_set_default_torch_dtype = MistralForCausalLM._set_default_torch_dtype
+
+        def debug(*args, **kwargs):
+            # call the method as usual, then raise a RuntimeError
+            original_set_default_torch_dtype(*args, **kwargs)
+            raise RuntimeError
+
+        with mock.patch(
+            "transformers.models.mistral.modeling_mistral.MistralForCausalLM._set_default_torch_dtype",
+            side_effect=debug,
+        ):
+            with self.assertRaises(RuntimeError):
+                config.torch_dtype = torch.float16
+                _ = AutoModelForCausalLM.from_config(
+                    config,
+                )
+        # default should still be float32
+        assert torch.get_default_dtype() == torch.float32
+        torch.set_default_dtype(old_dtype)
+
     def test_unknown_quantization_config(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             config = BertConfig(