Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
Add separated decoder_head_mask for T5 Models (#9634)
* Add decoder_head_mask for PyTorch T5 model
* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration
* Slightly change the order of input args to be in accordance with the convention from BART-based models introduced within PR #9569
* Make style for modeling_t5.py
* Add decoder_head_mask for TF T5 models
* Separate head_mask and decoder_head_mask args in TF T5 models
* Slightly change the order of input args to follow the convention of BART-based models updated in PR #9569
* Update test_forward_signature in tests/test_modeling_tf_common.py w.r.t. the changed order of input args
* Add FutureWarnings for T5 and TFT5 models, warning the user that the input argument `head_mask` was split into two arguments - `head_mask` and `decoder_head_mask`
* Add default behaviour - `decoder_head_mask` is set to copy `head_mask`
* Fix T5 modeling and FutureWarning
* Make proper usage of head_mask and decoder_head_mask in cross_attention
* Fix conditions for raising FutureWarning
* Reformat FutureWarning in T5 modeling
* Refactor the warning message
This commit is contained in: parent e4c06ed664, commit 2ebbbf558c.
3 changed files with 128 additions and 46 deletions
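For orientation before the diff itself: after this change, the T5 model classes take head_mask (encoder) and decoder_head_mask (decoder) as separate arguments. A minimal usage sketch against the PyTorch model; t5-small and the sentence pair are only illustrative, and num_layers/num_heads come from the model config:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: The house is small.", return_tensors="pt")
labels = tokenizer("Das Haus ist klein.", return_tensors="pt").input_ids

num_layers, num_heads = model.config.num_layers, model.config.num_heads
head_mask = torch.ones(num_layers, num_heads)          # encoder self-attention heads, 1 = keep
decoder_head_mask = torch.ones(num_layers, num_heads)  # decoder self-attention heads, 1 = keep
decoder_head_mask[0, 0] = 0.0                          # e.g. drop head 0 of decoder layer 0

outputs = model(
    **inputs,
    labels=labels,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
)
print(outputs.loss)

Passing only head_mask still works for now; as the diff below shows, it is copied to decoder_head_mask and a FutureWarning is raised.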
modeling_t5.py

@@ -18,6 +18,7 @@
 import copy
 import math
 import os
+import warnings

 import torch
 import torch.nn.functional as F
@@ -409,7 +410,7 @@ class T5Attention(nn.Module):
         key_value_states=None,
         position_bias=None,
         past_key_value=None,
-        head_mask=None,
+        layer_head_mask=None,
         query_length=None,
         use_cache=False,
         output_attentions=False,

@@ -504,8 +505,8 @@ class T5Attention(nn.Module):
         )  # (batch_size, n_heads, seq_length, key_length)

         # Mask heads if we want to
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+        if layer_head_mask is not None:
+            attn_weights = attn_weights * layer_head_mask

         attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
         attn_output = self.o(attn_output)
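The head-mask multiplication above is a plain elementwise broadcast over the attention weights. A standalone toy check of that effect (illustrative shapes, not the layer's actual code):

import torch

# attn_weights: (batch_size, n_heads, seq_length, key_length), as in the comment above
batch_size, n_heads, seq_length, key_length = 2, 8, 5, 5
attn_weights = torch.softmax(torch.randn(batch_size, n_heads, seq_length, key_length), dim=-1)

# layer_head_mask holds one value per head; reshape it so it broadcasts over the other dims
layer_head_mask = torch.ones(n_heads)
layer_head_mask[3] = 0.0  # silence head 3 of this layer

attn_weights = attn_weights * layer_head_mask.view(1, n_heads, 1, 1)
assert torch.all(attn_weights[:, 3] == 0)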
@@ -530,7 +531,7 @@ class T5LayerSelfAttention(nn.Module):
         hidden_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,

@@ -540,7 +541,7 @@ class T5LayerSelfAttention(nn.Module):
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,

@@ -563,7 +564,7 @@ class T5LayerCrossAttention(nn.Module):
         key_value_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         query_length=None,

@@ -575,7 +576,7 @@ class T5LayerCrossAttention(nn.Module):
             mask=attention_mask,
             key_value_states=key_value_states,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             use_cache=use_cache,
             query_length=query_length,

@@ -605,7 +606,8 @@ class T5Block(nn.Module):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
+        encoder_layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,

@@ -632,7 +634,7 @@ class T5Block(nn.Module):
             hidden_states,
             attention_mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=self_attn_past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,

@@ -659,7 +661,7 @@ class T5Block(nn.Module):
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
-                head_mask=head_mask,
+                layer_head_mask=encoder_layer_head_mask,
                 past_key_value=cross_attn_past_key_value,
                 query_length=query_length,
                 use_cache=use_cache,
@@ -839,6 +841,7 @@ class T5Stack(T5PreTrainedModel):
         encoder_attention_mask=None,
         inputs_embeds=None,
         head_mask=None,
+        encoder_head_mask=None,
         past_key_values=None,
         use_cache=None,
         output_attentions=None,

@@ -906,6 +909,7 @@ class T5Stack(T5PreTrainedModel):

         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+        encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers)
         present_key_value_states = () if use_cache else None
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None

@@ -930,6 +934,10 @@ class T5Stack(T5PreTrainedModel):
                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                 if encoder_decoder_position_bias is not None:
                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                if head_mask is not None:
+                    head_mask = head_mask.to(hidden_states.device)
+                if encoder_head_mask is not None:
+                    encoder_head_mask = encoder_head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

@@ -940,7 +948,8 @@ class T5Stack(T5PreTrainedModel):
                 encoder_hidden_states=encoder_hidden_states,
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
-                head_mask=head_mask[i],
+                layer_head_mask=head_mask[i],
+                encoder_layer_head_mask=encoder_head_mask[i] if encoder_head_mask is not None else None,
                 past_key_value=past_key_value,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
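self.get_head_mask comes from ModuleUtilsMixin: it expands a (num_layers, num_heads) or (num_heads,) mask so that head_mask[i] broadcasts against layer i's attention weights, and turns None into a list of per-layer Nones. A rough sketch of the behaviour the new encoder_head_mask relies on (checkpoint name and prints are illustrative):

import torch
from transformers import T5Model

model = T5Model.from_pretrained("t5-small")
num_layers, num_heads = model.config.num_layers, model.config.num_heads

expanded = model.get_head_mask(torch.ones(num_layers, num_heads), num_layers)
print(len(expanded), expanded[0].shape)  # one broadcastable slice per layer

print(model.get_head_mask(None, num_layers))  # [None, None, ...] - nothing gets masked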
@@ -1058,6 +1067,20 @@ T5_INPUTS_DOCSTRING = r"""
         decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
             also be used by default.
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
             `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a

@@ -1069,12 +1092,6 @@ T5_INPUTS_DOCSTRING = r"""
             If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
             (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
             instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
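The docstring accepts two mask shapes; a quick illustration of the difference (6 layers and 8 heads are t5-small's sizes, used here only as an example):

import torch

num_layers, num_heads = 6, 8

per_head = torch.ones(num_heads)               # (num_heads,): the same mask is applied in every layer
per_layer = torch.ones(num_layers, num_heads)  # (num_layers, num_heads): one row per layer
per_layer[2, 5] = 0.0                          # disable head 5 of layer 2 only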
@@ -1141,6 +1158,14 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
 """

+# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+num_heads)`.
+"""
+

 @add_start_docstrings(
     "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
@@ -1229,9 +1254,10 @@ class T5Model(T5PreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         use_cache=None,

@@ -1258,6 +1284,12 @@ class T5Model(T5PreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(

@@ -1298,7 +1330,8 @@ class T5Model(T5PreTrainedModel):
             past_key_values=past_key_values,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
-            head_mask=head_mask,
+            head_mask=decoder_head_mask,
+            encoder_head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
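Because decoder_head_mask defaults to a copy of head_mask (with a FutureWarning) when only head_mask is given and the encoder and decoder have the same number of layers, the deprecation path can be exercised directly. A small sketch; t5-small is just a convenient checkpoint and the dummy token ids are arbitrary:

import warnings

import torch
from transformers import T5Model

model = T5Model.from_pretrained("t5-small")
head_mask = torch.ones(model.config.num_layers, model.config.num_heads)

input_ids = torch.tensor([[10, 11, 12]])
decoder_input_ids = torch.tensor([[0]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Only head_mask is given: it is copied to decoder_head_mask and a FutureWarning is emitted
    model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, head_mask=head_mask)
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Passing an explicit all-ones decoder_head_mask keeps the old effective behaviour without the warning
model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    decoder_head_mask=torch.ones(model.config.num_layers, model.config.num_heads),
)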
@@ -1409,9 +1442,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         labels=None,

@@ -1447,6 +1481,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed

@@ -1503,7 +1543,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             past_key_values=past_key_values,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
-            head_mask=head_mask,
+            head_mask=decoder_head_mask,
+            encoder_head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
modeling_tf_t5.py

@@ -18,6 +18,7 @@
 import copy
 import itertools
 import math
+import warnings
 from typing import Tuple

 import tensorflow as tf
@@ -245,7 +246,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         key_value_states=None,
         position_bias=None,
         past_key_value=None,
-        head_mask=None,
+        layer_head_mask=None,
         query_length=None,
         use_cache=False,
         training=False,

@@ -342,8 +343,8 @@ class TFT5Attention(tf.keras.layers.Layer):
         weights = self.dropout(weights, training=training)  # (batch_size, n_heads, query_length, key_length)

         # Mask heads if we want to
-        if head_mask is not None:
-            weights = weights * head_mask
+        if layer_head_mask is not None:
+            weights = weights * layer_head_mask

         attn_output = tf.matmul(weights, value_states)  # (batch_size, n_heads, query_length, dim_per_head)
@@ -373,7 +374,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         hidden_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,

@@ -384,7 +385,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
             normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,

@@ -412,7 +413,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
         key_value_states,
         attention_mask=None,
         position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
         past_key_value=None,
         query_length=None,
         use_cache=False,

@@ -425,7 +426,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
             mask=attention_mask,
             key_value_states=key_value_states,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             query_length=query_length,
             use_cache=use_cache,

@@ -467,7 +468,8 @@ class TFT5Block(tf.keras.layers.Layer):
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        head_mask=None,
+        layer_head_mask=None,
+        encoder_layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
         output_attentions=False,

@@ -494,7 +496,7 @@ class TFT5Block(tf.keras.layers.Layer):
             hidden_states,
             attention_mask=attention_mask,
             position_bias=position_bias,
-            head_mask=head_mask,
+            layer_head_mask=layer_head_mask,
             past_key_value=self_attn_past_key_value,
             use_cache=use_cache,
             output_attentions=output_attentions,

@@ -516,7 +518,7 @@ class TFT5Block(tf.keras.layers.Layer):
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
-                head_mask=head_mask,
+                layer_head_mask=encoder_layer_head_mask,
                 past_key_value=cross_attn_past_key_value,
                 query_length=query_length,
                 use_cache=use_cache,

@@ -584,6 +586,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         encoder_attention_mask=None,
         inputs_embeds=None,
         head_mask=None,
+        encoder_head_mask=None,
         past_key_values=None,
         use_cache=None,
         output_attentions=None,

@@ -601,6 +604,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             encoder_attention_mask=encoder_attention_mask,
             inputs_embeds=inputs_embeds,
             head_mask=head_mask,
+            encoder_head_mask=encoder_head_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -709,6 +713,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):

         assert inputs["head_mask"] is None, "Head mask not supported"
         inputs["head_mask"] = [None] * self.num_hidden_layers
+        assert inputs["encoder_head_mask"] is None, "Encoder head mask not supported"
+        inputs["encoder_head_mask"] = [None] * self.num_hidden_layers
         present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
         all_hidden_states = () if inputs["output_hidden_states"] else None
         all_attentions = () if inputs["output_attentions"] else None
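As the asserts above suggest, the TF layers accept the new arguments at this commit but appear not to apply them yet: any non-None head_mask or encoder_head_mask trips the assertion, and the masks are replaced by per-layer Nones. Building the masks themselves is straightforward once support lands; the sizes below are illustrative, in practice they come from model.config:

import tensorflow as tf

num_layers, num_heads = 6, 8  # e.g. t5-small

head_mask = tf.ones((num_layers, num_heads))          # encoder self-attention heads
decoder_head_mask = tf.ones((num_layers, num_heads))  # decoder self-attention heads
print(head_mask.shape, decoder_head_mask.shape)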
@@ -727,7 +733,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 encoder_hidden_states=inputs["encoder_hidden_states"],
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
-                head_mask=inputs["head_mask"][i],
+                layer_head_mask=inputs["head_mask"][i],
+                encoder_layer_head_mask=inputs["encoder_head_mask"][i],
                 past_key_value=past_key_value,
                 use_cache=inputs["use_cache"],
                 output_attentions=inputs["output_attentions"],
@@ -950,6 +957,20 @@ T5_INPUTS_DOCSTRING = r"""
         decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
             also be used by default.
+        head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
+            1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
             `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a

@@ -973,12 +994,6 @@ T5_INPUTS_DOCSTRING = r"""

             If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
             takes the value of :obj:`inputs_embeds`.
-        head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
             decoding (see :obj:`past_key_values`).
@@ -1037,6 +1052,13 @@ T5_ENCODER_INPUTS_DOCSTRING = r"""
             behaviors between training and evaluation).
 """

+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = tf.ones((num_layers,
+num_heads))`.
+"""
+

 @add_start_docstrings(
     "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
@@ -1075,9 +1097,10 @@ class TFT5Model(TFT5PreTrainedModel):
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         use_cache=None,

@@ -1103,6 +1126,11 @@ class TFT5Model(TFT5PreTrainedModel):


         """
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+            decoder_head_mask = head_mask
+
         inputs = input_processing(
             func=self.call,
             config=self.config,

@@ -1110,9 +1138,10 @@ class TFT5Model(TFT5PreTrainedModel):
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,

@@ -1149,7 +1178,8 @@ class TFT5Model(TFT5PreTrainedModel):
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=inputs["attention_mask"],
             inputs_embeds=inputs["decoder_inputs_embeds"],
-            head_mask=inputs["head_mask"],
+            head_mask=inputs["decoder_head_mask"],
+            encoder_head_mask=inputs["head_mask"],
             past_key_values=inputs["past_key_values"],
             use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
@@ -1251,9 +1281,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
-        head_mask=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         labels=None,

@@ -1289,6 +1320,11 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         >>> result = model.generate(inputs)

         """
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+            decoder_head_mask = head_mask
+
         inputs = input_processing(
             func=self.call,
             config=self.config,

@@ -1296,9 +1332,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
-            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
             labels=labels,

@@ -1340,7 +1377,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=inputs["attention_mask"],
             inputs_embeds=inputs["decoder_inputs_embeds"],
-            head_mask=inputs["head_mask"],
+            head_mask=inputs["decoder_head_mask"],
             past_key_values=inputs["past_key_values"],
             use_cache=inputs["use_cache"],
             output_attentions=inputs["output_attentions"],
tests/test_modeling_tf_common.py

@@ -155,9 +155,13 @@ class TFModelTesterMixin:
                 "attention_mask",
                 "decoder_input_ids",
                 "decoder_attention_mask",
-                "encoder_outputs",
             ]
-            self.assertListEqual(arg_names[:5], expected_arg_names)
+            expected_arg_names.extend(
+                ["head_mask", "decoder_head_mask", "encoder_outputs"]
+                if "head_mask" and "decoder_head_mask" in arg_names
+                else ["encoder_outputs"]
+            )
+            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

         else:
             expected_arg_names = ["input_ids"]
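The updated test asserts that the leading arguments of call come in the BART-style order, with head_mask and decoder_head_mask right after decoder_attention_mask when a model supports them. A rough standalone version of that check (using TFT5Model directly; the real test iterates over all model classes):

import inspect

from transformers import TFT5Model  # requires TensorFlow

sig = inspect.signature(TFT5Model.call)
arg_names = [name for name in sig.parameters if name != "self"]

expected = [
    "input_ids",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
    "head_mask",
    "decoder_head_mask",
    "encoder_outputs",
]
assert arg_names[: len(expected)] == expected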