diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0ea405ed6..a050b81fd 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -18,6 +18,7 @@ import copy import math import os +import warnings import torch import torch.nn.functional as F @@ -409,7 +410,7 @@ class T5Attention(nn.Module): key_value_states=None, position_bias=None, past_key_value=None, - head_mask=None, + layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False, @@ -504,8 +505,8 @@ class T5Attention(nn.Module): ) # (batch_size, n_heads, seq_length, key_length) # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = self.o(attn_output) @@ -530,7 +531,7 @@ class T5LayerSelfAttention(nn.Module): hidden_states, attention_mask=None, position_bias=None, - head_mask=None, + layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False, @@ -540,7 +541,7 @@ class T5LayerSelfAttention(nn.Module): normed_hidden_states, mask=attention_mask, position_bias=position_bias, - head_mask=head_mask, + layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, @@ -563,7 +564,7 @@ class T5LayerCrossAttention(nn.Module): key_value_states, attention_mask=None, position_bias=None, - head_mask=None, + layer_head_mask=None, past_key_value=None, use_cache=False, query_length=None, @@ -575,7 +576,7 @@ class T5LayerCrossAttention(nn.Module): mask=attention_mask, key_value_states=key_value_states, position_bias=position_bias, - head_mask=head_mask, + layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, query_length=query_length, @@ -605,7 +606,8 @@ class T5Block(nn.Module): encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None, - head_mask=None, + layer_head_mask=None, + encoder_layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False, @@ -632,7 +634,7 @@ class T5Block(nn.Module): hidden_states, attention_mask=attention_mask, position_bias=position_bias, - head_mask=head_mask, + layer_head_mask=layer_head_mask, past_key_value=self_attn_past_key_value, use_cache=use_cache, output_attentions=output_attentions, @@ -659,7 +661,7 @@ class T5Block(nn.Module): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, - head_mask=head_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, query_length=query_length, use_cache=use_cache, @@ -839,6 +841,7 @@ class T5Stack(T5PreTrainedModel): encoder_attention_mask=None, inputs_embeds=None, head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -906,6 +909,7 @@ class T5Stack(T5PreTrainedModel): # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) + encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers) present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -930,6 +934,10 @@ class T5Stack(T5PreTrainedModel): encoder_extended_attention_mask = 
encoder_extended_attention_mask.to(hidden_states.device) if encoder_decoder_position_bias is not None: encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if head_mask is not None: + head_mask = head_mask.to(hidden_states.device) + if encoder_head_mask is not None: + encoder_head_mask = encoder_head_mask.to(hidden_states.device) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -940,7 +948,8 @@ class T5Stack(T5PreTrainedModel): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, - head_mask=head_mask[i], + layer_head_mask=head_mask[i], + encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, @@ -1058,6 +1067,20 @@ T5_INPUTS_DOCSTRING = r""" decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a @@ -1069,12 +1092,6 @@ T5_INPUTS_DOCSTRING = r""" If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices into associated @@ -1141,6 +1158,14 @@ T5_ENCODER_INPUTS_DOCSTRING = r""" Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. 
Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + @add_start_docstrings( "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", @@ -1229,9 +1254,10 @@ class T5Model(T5PreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, - head_mask=None, inputs_embeds=None, decoder_inputs_embeds=None, use_cache=None, @@ -1258,6 +1284,12 @@ class T5Model(T5PreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + # Encode if needed (training, first prediction pass) if encoder_outputs is None: encoder_outputs = self.encoder( @@ -1298,7 +1330,8 @@ class T5Model(T5PreTrainedModel): past_key_values=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, - head_mask=head_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1409,9 +1442,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, - head_mask=None, inputs_embeds=None, decoder_inputs_embeds=None, labels=None, @@ -1447,6 +1481,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + # Encode if needed (training, first prediction pass) if encoder_outputs is None: # Convert encoder inputs in embeddings if needed @@ -1503,7 +1543,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel): past_key_values=past_key_values, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, - head_mask=head_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 5bed19313..c1c65a5e4 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -18,6 +18,7 @@ import copy import itertools import math +import warnings from typing import Tuple import tensorflow as tf @@ -245,7 +246,7 @@ class TFT5Attention(tf.keras.layers.Layer): key_value_states=None, position_bias=None, past_key_value=None, - head_mask=None, + 
layer_head_mask=None, query_length=None, use_cache=False, training=False, @@ -342,8 +343,8 @@ class TFT5Attention(tf.keras.layers.Layer): weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length) # Mask heads if we want to - if head_mask is not None: - weights = weights * head_mask + if layer_head_mask is not None: + weights = weights * layer_head_mask attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head) @@ -373,7 +374,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer): hidden_states, attention_mask=None, position_bias=None, - head_mask=None, + layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False, @@ -384,7 +385,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer): normed_hidden_states, mask=attention_mask, position_bias=position_bias, - head_mask=head_mask, + layer_head_mask=layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, @@ -412,7 +413,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer): key_value_states, attention_mask=None, position_bias=None, - head_mask=None, + layer_head_mask=None, past_key_value=None, query_length=None, use_cache=False, @@ -425,7 +426,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer): mask=attention_mask, key_value_states=key_value_states, position_bias=position_bias, - head_mask=head_mask, + layer_head_mask=layer_head_mask, past_key_value=past_key_value, query_length=query_length, use_cache=use_cache, @@ -467,7 +468,8 @@ class TFT5Block(tf.keras.layers.Layer): encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None, - head_mask=None, + layer_head_mask=None, + encoder_layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False, @@ -494,7 +496,7 @@ class TFT5Block(tf.keras.layers.Layer): hidden_states, attention_mask=attention_mask, position_bias=position_bias, - head_mask=head_mask, + layer_head_mask=layer_head_mask, past_key_value=self_attn_past_key_value, use_cache=use_cache, output_attentions=output_attentions, @@ -516,7 +518,7 @@ class TFT5Block(tf.keras.layers.Layer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, - head_mask=head_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, query_length=query_length, use_cache=use_cache, @@ -584,6 +586,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_attention_mask=None, inputs_embeds=None, head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -601,6 +604,7 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_attention_mask=encoder_attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask, + encoder_head_mask=encoder_head_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, @@ -709,6 +713,8 @@ class TFT5MainLayer(tf.keras.layers.Layer): assert inputs["head_mask"] is None, "Head mask not supported" inputs["head_mask"] = [None] * self.num_hidden_layers + assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported" + inputs["encoder_head_mask"] = [None] * self.num_hidden_layers present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None all_hidden_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None @@ -727,7 +733,8 @@ class 
TFT5MainLayer(tf.keras.layers.Layer): encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, - head_mask=inputs["head_mask"][i], + layer_head_mask=inputs["head_mask"][i], + encoder_layer_head_mask=inputs["encoder_head_mask"][i], past_key_value=past_key_value, use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], @@ -950,6 +957,20 @@ T5_INPUTS_DOCSTRING = r""" decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. + head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a @@ -973,12 +994,6 @@ T5_INPUTS_DOCSTRING = r""" If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` takes the value of :obj:`inputs_embeds`. - head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). @@ -1037,6 +1052,13 @@ T5_ENCODER_INPUTS_DOCSTRING = r""" behaviors between training and evaluation). """ +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers, +num_heads))`. 
+""" + @add_start_docstrings( "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", @@ -1075,9 +1097,10 @@ class TFT5Model(TFT5PreTrainedModel): attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, - head_mask=None, inputs_embeds=None, decoder_inputs_embeds=None, use_cache=None, @@ -1103,6 +1126,11 @@ class TFT5Model(TFT5PreTrainedModel): """ + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + inputs = input_processing( func=self.call, config=self.config, @@ -1110,9 +1138,10 @@ class TFT5Model(TFT5PreTrainedModel): attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, - head_mask=head_mask, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1149,7 +1178,8 @@ class TFT5Model(TFT5PreTrainedModel): encoder_hidden_states=hidden_states, encoder_attention_mask=inputs["attention_mask"], inputs_embeds=inputs["decoder_inputs_embeds"], - head_mask=inputs["head_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"], use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], @@ -1251,9 +1281,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs=None, past_key_values=None, - head_mask=None, inputs_embeds=None, decoder_inputs_embeds=None, labels=None, @@ -1289,6 +1320,11 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling >>> result = model.generate(inputs) """ + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + inputs = input_processing( func=self.call, config=self.config, @@ -1296,9 +1332,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, - head_mask=head_mask, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, labels=labels, @@ -1340,7 +1377,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling encoder_hidden_states=hidden_states, encoder_attention_mask=inputs["attention_mask"], inputs_embeds=inputs["decoder_inputs_embeds"], - head_mask=inputs["head_mask"], + head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], use_cache=inputs["use_cache"], output_attentions=inputs["output_attentions"], diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 84391b3f5..794238faa 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -155,9 +155,13 @@ 
class TFModelTesterMixin: "attention_mask", "decoder_input_ids", "decoder_attention_mask", - "encoder_outputs", ] - self.assertListEqual(arg_names[:5], expected_arg_names) + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "encoder_outputs"] + if "head_mask" in arg_names and "decoder_head_mask" in arg_names + else ["encoder_outputs"] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) else: expected_arg_names = ["input_ids"]
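A minimal usage sketch of the split `head_mask` / `decoder_head_mask` arguments in the PyTorch model, assuming a `transformers` install that includes this change and the public `t5-small` checkpoint; the input sentences and the choice of masked heads are purely illustrative.

# Sketch: masking encoder and decoder attention heads separately in T5.
# Assumes `transformers` (with sentencepiece) and `torch` are installed.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

# Shape (num_layers, num_heads); 1 keeps a head, 0 masks it out.
head_mask = torch.ones(model.config.num_layers, model.config.num_heads)
head_mask[0, 0] = 0  # mask the first head of the first encoder layer

decoder_head_mask = torch.ones(model.config.num_decoder_layers, model.config.num_heads)
decoder_head_mask[-1, :2] = 0  # mask two heads in the last decoder layer

outputs = model(
    input_ids=input_ids,
    labels=labels,
    # applied to encoder self-attention and, per the change above, also
    # forwarded to the decoder cross-attention as `encoder_head_mask`
    head_mask=head_mask,
    # applied to decoder self-attention
    decoder_head_mask=decoder_head_mask,
)
print(float(outputs.loss))

# Passing only `head_mask` still works for now, but (when the encoder and decoder
# have the same number of layers) it emits a FutureWarning and copies the mask to
# `decoder_head_mask`; pass an explicit all-ones decoder mask to silence it:
# model(input_ids=input_ids, labels=labels, head_mask=head_mask,
#       decoder_head_mask=torch.ones(model.config.num_decoder_layers, model.config.num_heads))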