diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5180ae9af..5ce54d5d9 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -103,3 +103,4 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
    model_doc/xlmroberta
    model_doc/flaubert
    model_doc/bart
+   model_doc/t5
diff --git a/docs/source/model_doc/t5.rst b/docs/source/model_doc/t5.rst
new file mode 100644
index 000000000..3e4c28cf3
--- /dev/null
+++ b/docs/source/model_doc/t5.rst
@@ -0,0 +1,69 @@
+T5
+----------------------------------------------------
+**DISCLAIMER:** This model is still a work in progress. If you see something strange,
+file a `Github Issue <https://github.com/huggingface/transformers/issues/new>`_.
+
+Overview
+~~~~~~~~
+The T5 model was presented in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`_ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li and Peter J. Liu.
+Here is the abstract:
+
+*Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning has given rise to a diversity of approaches, methodology, and practice.
+In this paper, we explore the landscape of transfer learning techniques for NLP by introducing a unified framework that converts every language problem into a text-to-text format.
+Our systematic study compares pre-training objectives, architectures, unlabeled datasets, transfer approaches, and other factors on dozens of language understanding tasks.
+By combining the insights from our exploration with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering summarization, question answering, text classification, and more.
+To facilitate future work on transfer learning for NLP, we release our dataset, pre-trained models, and code.*
+
+The authors' code can be found `here <https://github.com/google-research/text-to-text-transfer-transformer>`_.
+
+Tips
+~~~~~~~~~~~~~~~~~~~~
+- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised
+  and supervised tasks, each of which is cast as a sequence-to-sequence task.
+  T5 therefore works well on a variety of tasks out-of-the-box by prepending a different prefix to the input corresponding to each task, e.g. for translation: *translate English to German: ...*, for summarization: *summarize: ...*.
+  For more information about which prefix to use, it is easiest to look into Appendix D of the `paper <https://arxiv.org/pdf/1910.10683.pdf>`_.
+- For sequence-to-sequence generation, it is recommended to use ``T5ForConditionalGeneration.generate()``; see the sketch after this list. The method takes care of feeding the encoded input via cross-attention layers to the decoder and auto-regressively generating the decoder output.
+- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
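+
+As a minimal sketch of the prefix-plus-``generate()`` pattern (the exact output depends on the
+pre-trained weights, which are downloaded on first use)::
+
+    from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+    tokenizer = T5Tokenizer.from_pretrained('t5-small')
+    model = T5ForConditionalGeneration.from_pretrained('t5-small')
+
+    # prepend the task prefix so the model knows which task to perform
+    input_ids = tokenizer.encode("translate English to German: The house is wonderful.", return_tensors="pt")
+
+    # generate() encodes the input once, then decodes auto-regressively
+    outputs = model.generate(input_ids)
+    print(tokenizer.decode(outputs[0]))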
+
+
+T5Config
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.T5Config
+    :members:
+
+
+T5Tokenizer
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.T5Tokenizer
+    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
+        create_token_type_ids_from_sequences, save_vocabulary
+
+
+T5Model
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.T5Model
+    :members:
+
+
+T5ForConditionalGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.T5ForConditionalGeneration
+    :members:
+
+
+TFT5Model
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFT5Model
+    :members:
+
+
+TFT5ForConditionalGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFT5ForConditionalGeneration
+    :members:
diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst
index 565c861cc..65f718c3d 100644
--- a/docs/source/pretrained_models.rst
+++ b/docs/source/pretrained_models.rst
@@ -275,7 +275,6 @@ For a list that includes community-uploaded models, refer to `https://huggingface.co/models <https://huggingface.co/models>`__.
 |                   |                                                            | | FlauBERT large architecture                                                                                                           |
 |                   |                                                            | (see `details `__)                                                                                                                     |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
 | Bart              | ``bart-large``                                             | | 12-layer, 1024-hidden, 16-heads, 406M parameters                                                                                      |
 |                   |                                                            | (see `details `_)                                                                                                                      |
 |                   +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
@@ -285,6 +284,3 @@ For a list that includes community-uploaded models, refer to `https://huggingface.co/models <https://huggingface.co/models>`__.
 |                   | ``bart-large-cnn``                                         | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base)                                                                       |
 |                   |                                                            | | bart-large base architecture finetuned on cnn summarization task                                                                     |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-
-
-.. `__
diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index fe374d79d..d2f92b005 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -72,6 +72,10 @@ BART_INPUTS_DOCSTRING = r"""
             Mask to avoid performing attention on padding token indices in input_ids.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
+            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`).
+            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder.
+            Used in the cross-attention of the decoder.
         decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
             Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
         decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
@@ -972,7 +976,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         Returns:
             :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
             loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
-                Classification loss (cross entropy)
+                Classification loss (cross entropy)
             logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
                 Classification (or regression if config.num_labels==1) scores (before SoftMax).
             hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py
index d03c40f5a..cdfa1ce91 100644
--- a/src/transformers/modeling_t5.py
+++ b/src/transformers/modeling_t5.py
@@ -27,7 +27,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 
 from .configuration_t5 import T5Config
-from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
+from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, prune_linear_layer
@@ -696,8 +696,8 @@ T5_START_DOCSTRING = r"""
     The T5 model was proposed in
 """
 
 T5_INPUTS_DOCSTRING = r"""
-    Inputs:
-        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
             To match pre-training, T5 input sequence should be formatted with [CLS] and [SEP] tokens as follows:
@@ -715,11 +715,27 @@ T5_INPUTS_DOCSTRING = r"""
             Indices can be obtained using :class:`transformers.T5Tokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
             :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+        encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
+            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`).
+            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder.
+            Used in the cross-attention of the decoder.
+        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
+            Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
+        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
+            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be applied by default.
+        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
+            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
@@ -729,31 +745,8 @@
 @add_start_docstrings(
     "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
     T5_START_DOCSTRING,
-    T5_INPUTS_DOCSTRING,
 )
 class T5Model(T5PreTrainedModel):
-    r"""
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
-            Sequence of hidden-states at the output of the last layer of the model.
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
-    Examples::
-
-        tokenizer = T5Tokenizer.from_pretrained('t5-small')
-        model = T5Model.from_pretrained('t5-small')
-        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
-        outputs = model(input_ids=input_ids)
-        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
-
-    """
-
     def __init__(self, config):
         super().__init__(config)
         self.shared = nn.Embedding(config.vocab_size, config.d_model)
@@ -783,6 +776,7 @@ class T5Model(T5PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
+    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids=None,
@@ -794,6 +788,34 @@ class T5Model(T5PreTrainedModel):
         decoder_inputs_embeds=None,
         head_mask=None,
     ):
+        r"""
+        Return:
+            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
+            last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+                Sequence of hidden-states at the output of the last layer of the model.
+            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+                heads.
+
+        Examples::
+
+            from transformers import T5Tokenizer, T5Model
+
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = T5Model.from_pretrained('t5-small')
+            input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")  # Batch size 1
+            outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
+            last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
+
+        """
 
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
@@ -816,38 +838,8 @@ class T5Model(T5PreTrainedModel):
 
         return decoder_outputs + encoder_outputs
 
 
-@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
+@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
 class T5ForConditionalGeneration(T5PreTrainedModel):
-    r"""
-    **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-        Labels for computing the masked language modeling loss.
-        Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
-        Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
-        in ``[0, ..., config.vocab_size]``.
-
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
-            Masked language modeling loss.
-        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
-    Examples::
-
-        tokenizer = T5Tokenizer.from_pretrained('t5-small')
-        model = T5ForConditionalGeneration.from_pretrained('t5-small')
-        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
-        outputs = model(input_ids=input_ids, lm_labels=input_ids)
-        loss, prediction_scores = outputs[:2]
-
-    """
-
     def __init__(self, config):
         super().__init__(config)
         self.model_dim = config.d_model
@@ -879,6 +871,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     def get_encoder(self):
         return self.encoder
 
+    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids=None,
@@ -891,6 +884,43 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         decoder_inputs_embeds=None,
         head_mask=None,
     ):
+        r"""
+        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence to sequence language modeling loss.
+            Indices should be in :obj:`[0, ..., config.vocab_size - 1]`.
+
+        Returns:
+            :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
+            loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
+                Language modeling loss (cross entropy).
+            prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+            hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+                Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+                heads.
+
+        Examples::
+
+            from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = T5ForConditionalGeneration.from_pretrained('t5-small')
+            input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")  # Batch size 1
+            outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, lm_labels=input_ids)
+            loss, prediction_scores = outputs[:2]
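+
+            # Illustration only (this helper code is not part of the model API): the statement
+            # that T5 uses pad_token_id as the starting token for decoder_input_ids means the
+            # targets are shifted one position to the right, e.g.:
+            import torch
+            starts = torch.full((input_ids.shape[0], 1), model.config.pad_token_id, dtype=torch.long)
+            decoder_input_ids = torch.cat([starts, input_ids[:, :-1]], dim=-1)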
+
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = T5ForConditionalGeneration.from_pretrained('t5-small')
+            input_ids = tokenizer.encode("summarize: Hello, my dog is cute", return_tensors="pt")  # Batch size 1
+            outputs = model.generate(input_ids)
+        """
 
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py
index c17b2b678..ddc6b7a80 100644
--- a/src/transformers/modeling_tf_t5.py
+++ b/src/transformers/modeling_tf_t5.py
@@ -24,7 +24,7 @@ import math
 import tensorflow as tf
 
 from .configuration_t5 import T5Config
-from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
+from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
@@ -630,8 +630,12 @@ T5_START_DOCSTRING = r"""
     The T5 model was proposed in
 """
 
 T5_INPUTS_DOCSTRING = r"""
-    Inputs:
-        **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
+    Args:
+        Note that :obj:`decoder_input_ids` (the first positional argument) is usually passed as a `dict` (see the T5 description above for more information) containing all of the following inputs.
+        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
+            Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
+
+        input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
             To match pre-training, T5 input sequence should be formatted with [CLS] and [SEP] tokens as follows:
@@ -643,18 +647,31 @@ T5_INPUTS_DOCSTRING = r"""
 
             ``tokens: [CLS] the dog is hairy . [SEP]``
 
-            T5 is a model with relative position embeddings so you should be able to pad the inputs on the right or the left.
 
             Indices can be obtained using :class:`transformers.T5Tokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
             :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
+        attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+        encoder_outputs (:obj:`tuple(tf.Tensor)`, `optional`, defaults to :obj:`None`):
+            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`).
+            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder.
+            Used in the cross-attention of the decoder.
+        decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
+            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be applied by default.
+        inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        decoder_inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
+            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        head_mask (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
@@ -664,34 +681,8 @@
 @add_start_docstrings(
     "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
     T5_START_DOCSTRING,
-    T5_INPUTS_DOCSTRING,
 )
 class TFT5Model(TFT5PreTrainedModel):
-    r"""
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
-            Sequence of hidden-states at the output of the last layer of the model.
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
-    Examples::
-
-        import tensorflow as tf
-        from transformers import T5Tokenizer, TFT5Model
-
-        tokenizer = T5Tokenizer.from_pretrained('t5-small')
-        model = TFT5Model.from_pretrained('t5-small')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
-        outputs = model(input_ids=input_ids)
-        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
-
-    """
-
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
@@ -715,7 +706,36 @@ class TFT5Model(TFT5PreTrainedModel):
     def get_output_embeddings(self):
         return self.shared
 
+    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
     def call(self, decoder_input_ids, **kwargs):
+        r"""
+        Return:
+            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
+            last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+                Sequence of hidden-states at the output of the last layer of the model.
+            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+                Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+                Tuple of :obj:`tf.Tensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+                heads.
+
+        Examples::
+
+            from transformers import T5Tokenizer, TFT5Model
+
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = TFT5Model.from_pretrained('t5-small')
+            input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
+            outputs = model(input_ids, input_ids=input_ids)  # decoder_input_ids is the first positional argument
+            last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
+
+        """
         if isinstance(decoder_input_ids, dict):
             kwargs.update(decoder_input_ids)
@@ -753,33 +773,8 @@ class TFT5Model(TFT5PreTrainedModel):
 
         return decoder_outputs + encoder_outputs
 
 
-@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
+@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
 class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
-    r"""
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
-    Examples::
-
-        import tensorflow as tf
-        from transformers import T5Tokenizer, TFT5ForConditionalGeneration
-
-        tokenizer = T5Tokenizer.from_pretrained('t5-small')
-        model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
-        outputs = model(input_ids=input_ids)
-        prediction_scores = outputs[0]
-
-    """
-
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.model_dim = config.d_model
@@ -808,7 +803,47 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
     def get_encoder(self):
         return self.encoder
 
+    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
     def call(self, decoder_input_ids, **kwargs):
+        r"""
+        lm_labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence to sequence language modeling loss.
+            Indices should be in :obj:`[0, ..., config.vocab_size - 1]`.
+
+        Return:
+            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
+            loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
+                Language modeling loss (cross entropy).
+            prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+                Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+                Tuple of :obj:`tf.Tensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+                heads.
+
+        Examples::
+
+            from transformers import T5Tokenizer, TFT5ForConditionalGeneration
+
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
+            input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
+            outputs = model(input_ids, input_ids=input_ids, lm_labels=input_ids)
+            prediction_scores = outputs[0]  # TODO: TFT5 still needs to implement the loss computation
+
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
+            input_ids = tokenizer.encode("summarize: Hello, my dog is cute", return_tensors="tf")  # Batch size 1
+            model.generate(input_ids)
+
+        """
         if isinstance(decoder_input_ids, dict):
             kwargs.update(decoder_input_ids)
@@ -844,6 +879,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
             head_mask=head_mask,
         )
 
+        # TODO (thom / patrick): add lm_labels for loss function
         sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
         embed_tokens = self.get_output_embeddings()
         lm_logits = embed_tokens(sequence_output, mode="linear")
diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py
index 3d0d4e698..bc1cf9fd0 100644
--- a/src/transformers/tokenization_t5.py
+++ b/src/transformers/tokenization_t5.py
@@ -61,14 +61,34 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
 
 class T5Tokenizer(PreTrainedTokenizer):
     """
-    SentencePiece based tokenizer. Peculiarities:
+    Constructs a T5 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .
 
-        - requires `SentencePiece <https://github.com/google/sentencepiece>`_
-        - `extra_ids` add a number of extra ids added to the end of the vocabulary for use as sentinels.
-          These tokens are accessible as `<extra_id_{%d}>` where `{%d}` is a number between 0 and extra_ids-1.
-          Extra tokens are indexed from the end of the vocabulary up to beginnning (<extra_id_0> is the last token in the vocabulary)
-          (like in T5 preprocessing
+    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
+    should refer to the superclass for more information regarding methods.
+
+    Args:
+        vocab_file (:obj:`string`):
+            `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        eos_token (:obj:`string`, `optional`, defaults to "</s>"):
+            The end of sequence token.
+        unk_token (:obj:`string`, `optional`, defaults to "<unk>"):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (:obj:`string`, `optional`, defaults to "<pad>"):
+            The token used for padding, for example when batching sequences of different lengths.
+        extra_ids (:obj:`int`, `optional`, defaults to :obj:`100`):
+            Adds a number of extra ids to the end of the vocabulary for use as sentinels.
+            These tokens are accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1.
+            Extra tokens are indexed from the end of the vocabulary up to the beginning ("<extra_id_0>" is the last token in the vocabulary, like in T5 preprocessing;
+            see https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117).
+        additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
+            Additional special tokens used by the tokenizer.
     """
 
     vocab_files_names = VOCAB_FILES_NAMES
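
The ``extra_ids`` sentinel mechanism documented above can be exercised as follows (a minimal
sketch, assuming the default ``extra_ids=100`` and the ``t5-small`` checkpoint)::

    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained('t5-small')

    # sentinel tokens are appended to the end of the vocabulary,
    # so <extra_id_0> maps to the last id and <extra_id_1> to the one before it
    ids = tokenizer.encode("The <extra_id_0> walks in <extra_id_1> park")
    print(tokenizer.convert_ids_to_tokens(ids))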