diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d7e0580e4..2d7db724f 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1173,14 +1173,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         Args:
             torch_dtype (`torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model under this dtype.
+            use_flash_attention_2 (`bool`, *optional*):
+                Whether to load the model with Flash Attention 2 modules.
         """
         torch_dtype = kwargs.pop("torch_dtype", None)
+        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
 
         # override default dtype if needed
         dtype_orig = None
         if torch_dtype is not None:
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
 
+        if use_flash_attention_2:
+            config = cls._check_and_enable_flash_attn_2(config, torch_dtype)
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 0edc23c7a..49d64dc20 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -33,6 +33,7 @@ from pytest import mark
 import transformers
 from transformers import (
     AutoModel,
+    AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     PretrainedConfig,
     is_torch_available,
@@ -3269,6 +3270,53 @@ class ModelTesterMixin:
             # Check models are equal
             self.assertTrue(check_models_equal(flax_model_1, flax_model_2))
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_from_config(self):
+        import torch
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            # TODO: to change it in the future with other relevant auto classes
+            fa2_model = AutoModelForCausalLM.from_config(
+                config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
+            ).to(torch_device)
+
+            dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+            dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
+
+            fa2_correctly_converted = False
+
+            for _, module in fa2_model.named_modules():
+                if "FlashAttention" in module.__class__.__name__:
+                    fa2_correctly_converted = True
+                    break
+
+            self.assertTrue(fa2_correctly_converted)
+
+            _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                fa2_model.save_pretrained(tmpdirname)
+
+                model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
+
+                self.assertFalse(getattr(model_from_pretrained.config, "_flash_attn_2_enabled", False))
+
+                fa2_correctly_converted = False
+
+                for _, module in model_from_pretrained.named_modules():
+                    if "FlashAttention" in module.__class__.__name__:
+                        fa2_correctly_converted = True
+                        break
+
+                self.assertFalse(fa2_correctly_converted)
+
 
 global_rng = random.Random()
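
Usage sketch (not part of the patch): the new test builds a Flash Attention 2 model directly from a config, and the same pattern applies outside the test suite. This is a minimal example of the kwarg added above; the checkpoint name is only illustrative, and any config whose architecture sets `_supports_flash_attn_2` works.

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    # Any FA2-capable architecture works; the model id here is illustrative.
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

    # Builds the model from the config alone (randomly initialized weights),
    # swapping in Flash Attention 2 attention modules via the new kwarg.
    model = AutoModelForCausalLM.from_config(
        config, use_flash_attention_2=True, torch_dtype=torch.bfloat16
    ).to("cuda")

As the test asserts, the flag is a load-time option only: saving the model and reloading it with plain `from_pretrained` does not persist `_flash_attn_2_enabled` in the config.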