diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 87bca772f..cb04ff337 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -409,7 +409,9 @@ class GenerationMixin:
             # retrieve encoder hidden states
             encoder = self.get_encoder()
             encoder_kwargs = {
-                argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
+                argument: value
+                for argument, value in model_kwargs.items()
+                if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
             }
             model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
         return model_kwargs
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 8f72c64d4..1c66f06a0 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1327,6 +1327,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -1342,6 +1344,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index ea3f54533..0c3860f85 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2530,6 +2530,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -2545,6 +2547,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index 5620c7788..ce4c15160 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1321,6 +1321,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -1336,6 +1338,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 7ddc2e765..d3e80f022 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1296,6 +1296,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -1311,6 +1313,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index 54da504ab..ce0807b1a 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -1215,7 +1215,16 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
     )
     def prepare_inputs_for_generation(
-        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
     ):
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1223,6 +1232,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 79f33d1db..2541121a2 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -2356,6 +2356,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs,
@@ -2371,6 +2373,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 4db2be333..4c5803269 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -1324,6 +1324,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs,
@@ -1339,6 +1341,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index dc40dacc4..762113845 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -1309,6 +1309,8 @@ class MarianMTModel(MarianPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -1324,6 +1326,8 @@ class MarianMTModel(MarianPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index a445539be..8e9b24499 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -1327,7 +1327,16 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
     )
     def prepare_inputs_for_generation(
-        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
     ):
         # cut decoder_input_ids if past is used
         if past is not None:
@@ -1339,6 +1348,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index e43a0bcbb..a8b1ce05b 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -1312,6 +1312,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs
@@ -1327,6 +1329,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 64d8d36e3..c2f642b99 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -2020,6 +2020,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         past=None,
         attention_mask=None,
         head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
         **kwargs,
@@ -2036,6 +2038,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,
         }
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 02b79d890..1fceb7b77 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -1655,7 +1655,16 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     )
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self,
+        input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
     ):

         # cut decoder_input_ids if past is used
@@ -1667,6 +1676,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             "past_key_values": past,
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,
         }
diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py
index 4830e07a2..b377bd3fa 100644
--- a/tests/test_generation_utils.py
+++ b/tests/test_generation_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.


+import inspect
 import unittest

 from transformers import is_torch_available
@@ -1072,6 +1073,40 @@ class GenerationTesterMixin:
             output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
         )

+    def test_generate_with_head_masking(self):
+        """Test designed for encoder-decoder models to ensure the attention head masking is used."""
+        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+        for model_class in self.all_generative_model_classes:
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+            model = model_class(config)
+            # We want to test only encoder-decoder models
+            if not config.is_encoder_decoder:
+                continue
+
+            head_masking = {
+                "head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads),
+                "decoder_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
+                "cross_attn_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
+            }
+
+            signature = inspect.signature(model.forward)
+            # We want to test only models where encoder/decoder head masking is implemented
+            if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
+                continue
+
+            for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
+                out = model.generate(
+                    input_ids,
+                    num_beams=1,
+                    max_length=max_length,
+                    output_attentions=True,
+                    return_dict_in_generate=True,
+                    **{name: mask},
+                )
+                # We check the state of decoder_attentions and cross_attentions just from the last step
+                attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
+                self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
+
     def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
         batch_size, seq_length = input_ids.shape
         num_sequences_in_output = batch_size * num_return_sequences
diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py
index caeb84131..32f100044 100644
--- a/tests/test_modeling_prophetnet.py
+++ b/tests/test_modeling_prophetnet.py
@@ -1088,6 +1088,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         self.assertIsNotNone(encoder_hidden_states.grad)
         self.assertIsNotNone(encoder_attentions.grad)

+    def test_generate_with_head_masking(self):
+        """Generating with head_masking has not been implemented for ProphetNet models yet."""
+        pass
+

 @require_torch
 class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py
index 55b9c0568..7f538a36c 100644
--- a/tests/test_modeling_t5.py
+++ b/tests/test_modeling_t5.py
@@ -600,6 +600,37 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
             input_names=["input_ids", "decoder_input_ids"],
         )

+    def test_generate_with_head_masking(self):
+        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        config = config_and_inputs[0]
+        max_length = config_and_inputs[1].shape[-1] + 3
+        model = T5ForConditionalGeneration(config)
+
+        head_masking = {
+            "head_mask": torch.zeros(config.num_layers, config.num_heads),
+            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
+            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
+        }
+
+        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
+            head_masks = {name: mask}
+            # Explicitly pass decoder_head_mask as it is required by the T5 model when head_mask is specified
+            if name == "head_mask":
+                head_masks["decoder_head_mask"] = torch.ones(config.num_decoder_layers, config.num_heads)
+
+            out = model.generate(
+                config_and_inputs[1],
+                num_beams=1,
+                max_length=max_length,
+                output_attentions=True,
+                return_dict_in_generate=True,
+                **head_masks,
+            )
+            # We check the state of decoder_attentions and cross_attentions just from the last step
+            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
+            self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
+

 class T5EncoderOnlyModelTester:
     def __init__(
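Usage sketch (not part of the patch): with these changes, the three head masks flow straight through `generate()` for encoder-decoder models, while the `cross_attn` prefix check in `generation_utils.py` keeps `cross_attn_head_mask` from being forwarded to the encoder. The snippet below is a minimal illustration; the checkpoint name, input sentence, and mask values are assumptions for the example only.

    import torch
    from transformers import BartForConditionalGeneration, BartTokenizer

    # Illustrative checkpoint; any encoder-decoder model touched by this patch works the same way.
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    inputs = tokenizer("Studies have shown that owning a dog is good for you.", return_tensors="pt")

    # Head masks have shape (num_layers, num_heads); 1.0 keeps a head, 0.0 masks it out.
    config = model.config
    decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
    decoder_head_mask[0, 0] = 0.0  # illustrative: mask the first head of the first decoder layer
    cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)

    generated = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        max_length=20,
    )
    print(tokenizer.decode(generated[0], skip_special_tokens=True))

The new tests above assert exactly this path: an all-zeros mask for a given component must zero out the corresponding returned attention weights.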