From bbf26c4e619cf42106163e1e2cd5ff98b936ff93 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 19 Mar 2020 23:18:23 +0100 Subject: [PATCH] Support T5 Generation (#3228) * fix conflicts * update bart max length test * correct spelling mistakes * implemented model specific encode function * fix merge conflicts * better naming * save intermediate state -> need to rethink strucuture a bit * leave tf problem as it is for now * current version * add layers.pop * remove ipdb * make style * clean return cut decoding * remove ipdbs * Fix restoring layers in the decoders that doesnt exists. * push good intermediate solution for now * fix conflicts * always good to refuse to merge conflicts when rebasing * fix small bug * improve function calls * remove unused file * add correct scope behavior for t5_generate Co-authored-by: Morgan Funtowicz --- src/transformers/__init__.py | 4 +- src/transformers/configuration_t5.py | 2 + .../convert_pytorch_checkpoint_to_tf2.py | 10 +- src/transformers/modeling_auto.py | 6 +- src/transformers/modeling_bart.py | 18 +- src/transformers/modeling_t5.py | 202 +++++++++------- src/transformers/modeling_tf_auto.py | 10 +- src/transformers/modeling_tf_pytorch_utils.py | 2 + src/transformers/modeling_tf_t5.py | 224 ++++++++++++------ src/transformers/modeling_tf_utils.py | 56 ++++- src/transformers/modeling_utils.py | 57 +++-- tests/test_modeling_bart.py | 10 +- tests/test_modeling_common.py | 12 +- tests/test_modeling_t5.py | 69 +++--- tests/test_modeling_tf_common.py | 17 +- tests/test_modeling_tf_t5.py | 30 ++- 16 files changed, 449 insertions(+), 280 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e2332654e..82016686e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -255,7 +255,7 @@ if is_torch_available(): from .modeling_t5 import ( T5PreTrainedModel, T5Model, - T5WithLMHeadModel, + T5ForConditionalGeneration, load_tf_weights_in_t5, T5_PRETRAINED_MODEL_ARCHIVE_MAP, ) @@ -444,7 +444,7 @@ if is_tf_available(): from .modeling_tf_t5 import ( TFT5PreTrainedModel, TFT5Model, - TFT5WithLMHeadModel, + TFT5ForConditionalGeneration, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP, ) diff --git a/src/transformers/configuration_t5.py b/src/transformers/configuration_t5.py index 767bec762..a86bb2f3b 100644 --- a/src/transformers/configuration_t5.py +++ b/src/transformers/configuration_t5.py @@ -76,6 +76,8 @@ class T5Config(PretrainedConfig): layer_norm_epsilon=1e-6, initializer_factor=1.0, is_encoder_decoder=True, + pad_token_id=0, + eos_token_ids=[1], **kwargs ): super().__init__( diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index 4fb08e0f7..3f3c923d5 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -57,7 +57,7 @@ from transformers import ( TFOpenAIGPTLMHeadModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, - TFT5WithLMHeadModel, + TFT5ForConditionalGeneration, TFTransfoXLLMHeadModel, TFXLMRobertaForMaskedLM, TFXLMWithLMHeadModel, @@ -108,7 +108,7 @@ if is_torch_available(): CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, - T5WithLMHeadModel, + T5ForConditionalGeneration, T5_PRETRAINED_MODEL_ARCHIVE_MAP, ) else: @@ -145,7 +145,7 @@ else: CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, - T5WithLMHeadModel, + T5ForConditionalGeneration, 
T5_PRETRAINED_MODEL_ARCHIVE_MAP, ) = ( None, @@ -316,8 +316,8 @@ MODEL_CLASSES = { ), "t5": ( T5Config, - TFT5WithLMHeadModel, - T5WithLMHeadModel, + TFT5ForConditionalGeneration, + T5ForConditionalGeneration, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP, ), diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 98b202105..958b9f25d 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -93,7 +93,7 @@ from .modeling_roberta import ( RobertaForTokenClassification, RobertaModel, ) -from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel +from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5ForConditionalGeneration, T5Model from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel from .modeling_xlm import ( XLM_PRETRAINED_MODEL_ARCHIVE_MAP, @@ -166,7 +166,7 @@ MODEL_MAPPING = OrderedDict( MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( [ - (T5Config, T5WithLMHeadModel), + (T5Config, T5ForConditionalGeneration), (DistilBertConfig, DistilBertForMaskedLM), (AlbertConfig, AlbertForMaskedLM), (CamembertConfig, CamembertForMaskedLM), @@ -186,7 +186,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ - (T5Config, T5WithLMHeadModel), + (T5Config, T5ForConditionalGeneration), (DistilBertConfig, DistilBertForMaskedLM), (AlbertConfig, AlbertForMaskedLM), (CamembertConfig, CamembertForMaskedLM), diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index b1c3e0346..c74b08ef1 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -885,18 +885,17 @@ class BartForConditionalGeneration(PretrainedBartModel): return outputs - def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask): - assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format( - attention_mask.shape, encoder_inputs.shape - ) - if past is None: # first step - encoder_outputs, decoder_cached_states = None, None + def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + # first step, decoder_cached_states are empty + if not past[1]: + encoder_outputs, decoder_cached_states = past, None else: encoder_outputs, decoder_cached_states = past - input_ids = encoder_inputs return { - "input_ids": input_ids, # ignored after first pass + "input_ids": None, # encoder_outputs is defined. 
input_ids not needed "encoder_outputs": encoder_outputs, "decoder_cached_states": decoder_cached_states, "decoder_input_ids": decoder_input_ids, @@ -929,6 +928,9 @@ class BartForConditionalGeneration(PretrainedBartModel): past = ((new_enc_out, new_enc_mask), reordered_past) return past + def get_encoder(self): + return self.model.encoder + def get_output_embeddings(self): return self.lm_head diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 2c8e7d827..b56917ae1 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -464,7 +464,7 @@ class T5PreTrainedModel(PreTrainedModel): input_mask = torch.tensor(DUMMY_MASK) dummy_inputs = { "decoder_input_ids": input_ids, - "encoder_input_ids": input_ids, + "input_ids": input_ids, "decoder_attention_mask": input_mask, } return dummy_inputs @@ -474,7 +474,7 @@ class T5PreTrainedModel(PreTrainedModel): factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, T5LayerNorm): module.weight.data.fill_(factor * 1.0) - elif isinstance(module, (T5Model, T5WithLMHeadModel)): + elif isinstance(module, (T5Model, T5ForConditionalGeneration)): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) @@ -503,10 +503,12 @@ class T5PreTrainedModel(PreTrainedModel): class T5Stack(T5PreTrainedModel): - def __init__(self, config): + def __init__(self, config, embed_tokens=None): super().__init__(config) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states + + self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder self.block = nn.ModuleList( @@ -517,21 +519,46 @@ class T5Stack(T5PreTrainedModel): self.init_weights() + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + def forward( self, - hidden_states, + input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + inputs_embeds=None, head_mask=None, ): - batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1] + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + if attention_mask is None: - attention_mask = torch.ones(batch_size, seq_length).to(hidden_states.device) + attention_mask = torch.ones(batch_size, seq_length).to(inputs_embeds.device) if self.is_decoder and encoder_attention_mask is None: encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(hidden_states.device) + encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device) # We can provide a self-attention mask of dimensions 
[batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. @@ -542,7 +569,7 @@ class T5Stack(T5PreTrainedModel): # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder: - seq_ids = torch.arange(seq_length, device=hidden_states.device) + seq_ids = torch.arange(seq_length, device=inputs_embeds.device) causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] causal_mask = causal_mask.to(attention_mask) extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] @@ -605,7 +632,7 @@ class T5Stack(T5PreTrainedModel): position_bias = None encoder_decoder_position_bias = None - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(inputs_embeds) for i, layer_module in enumerate(self.block): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -731,11 +758,11 @@ class T5Model(T5PreTrainedModel): self.shared = nn.Embedding(config.vocab_size, config.d_model) encoder_config = copy.deepcopy(config) - self.encoder = T5Stack(encoder_config) + self.encoder = T5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - self.decoder = T5Stack(decoder_config) + self.decoder = T5Stack(decoder_config, self.shared) self.init_weights() @@ -744,6 +771,8 @@ class T5Model(T5PreTrainedModel): def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. @@ -753,55 +782,41 @@ class T5Model(T5PreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - def forward(self, **kwargs): - # keyword arguments come in 3 flavors: encoder-specific (prefixed by - # `encoder_`), decoder-specific (prefixed by `decoder_`) and those - # that apply to the model as whole. - # We let the specific kwargs override the common ones in case of conflict. 
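Note: the `T5Stack` hunks above replace the precomputed `hidden_states` argument with `input_ids`/`inputs_embeds` and hand each stack a reference to the shared embedding table, so encoder and decoder now embed their own inputs from a single weight matrix. A minimal sketch of that sharing pattern (hypothetical `TinyStack`, not the real module):

```python
import copy

import torch
import torch.nn as nn


class TinyStack(nn.Module):
    """Stand-in for T5Stack: embeds its own inputs from a shared table."""

    def __init__(self, config, embed_tokens=None):
        super().__init__()
        self.embed_tokens = embed_tokens                 # same nn.Embedding for both stacks
        self.layer = nn.Linear(config["d_model"], config["d_model"])

    def forward(self, input_ids=None, inputs_embeds=None):
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        return self.layer(inputs_embeds)


config = {"vocab_size": 100, "d_model": 16}
shared = nn.Embedding(config["vocab_size"], config["d_model"])
encoder = TinyStack(copy.deepcopy(config), shared)
decoder = TinyStack(copy.deepcopy(config), shared)

ids = torch.randint(0, config["vocab_size"], (2, 5))
assert encoder(input_ids=ids).shape == (2, 5, 16)
assert encoder.embed_tokens is decoder.embed_tokens      # one weight matrix, two stacks
```

Because both stacks hold the same module, `set_input_embeddings(new_embeddings)` on the parent model now also has to propagate the new table to `self.encoder` and `self.decoder`, which is exactly what the `T5Model.set_input_embeddings` hunk above adds.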
- kwargs_common = dict( - (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") - ) - kwargs_encoder = kwargs_common.copy() - kwargs_decoder = kwargs_common.copy() - kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) - kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_outputs=None, + decoder_input_ids=None, + decoder_attention_mask=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + head_mask=None, + ): # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) - encoder_attention_mask = kwargs_encoder.get("attention_mask", None) - if encoder_hidden_states is None: - # Convert encoder inputs in embeddings if needed - hidden_states = kwargs_encoder.pop("inputs_embeds", None) - if hidden_states is None: - encoder_inputs_ids = kwargs_encoder.pop("input_ids") - hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask + ) - if encoder_attention_mask is not None: - # Apply masking - encoder_attention_mask = (encoder_attention_mask != 0).to(hidden_states) - hidden_states = hidden_states * encoder_attention_mask.unsqueeze(-1) - - encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0] - else: - encoder_outputs = () + hidden_states = encoder_outputs[0] # Decode - # Convert decoder inputs in embeddings if needed - hidden_states = kwargs_decoder.pop("inputs_embeds", None) - if hidden_states is None: - decoder_inputs_ids = kwargs_decoder.pop("input_ids") - hidden_states = self.shared(decoder_inputs_ids) - - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - kwargs_decoder["encoder_attention_mask"] = encoder_attention_mask - decoder_outputs = self.decoder(hidden_states, **kwargs_decoder) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=head_mask, + ) return decoder_outputs + encoder_outputs @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING, T5_INPUTS_DOCSTRING) -class T5WithLMHeadModel(T5PreTrainedModel): +class T5ForConditionalGeneration(T5PreTrainedModel): r""" **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for computing the masked language modeling loss. 
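With the hunks above, the old kwargs-splitting forward (everything prefixed `encoder_`/`decoder_`) is gone; `T5Model.forward` now takes explicit `input_ids`, `attention_mask`, `decoder_input_ids`, `decoder_attention_mask` and optional `encoder_outputs`. A hedged call-site sketch of the change, using a tiny randomly initialised config so nothing has to be downloaded (argument names follow this diff; shapes are illustrative only):

```python
import torch
from transformers import T5Config, T5Model

config = T5Config(vocab_size=100, d_model=16, d_ff=37, num_layers=2, num_heads=4, d_kv=4)
model = T5Model(config)

input_ids = torch.randint(0, 100, (1, 7))          # stand-in for tokenized source text
decoder_input_ids = torch.randint(0, 100, (1, 5))  # stand-in for tokenized target text

# before this patch: model(encoder_input_ids=..., decoder_input_ids=...)
# after it, the encoder arguments lose their "encoder_" prefix:
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
assert outputs[0].shape == (1, 5, 16)              # decoder hidden states come first in the tuple
```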
@@ -825,7 +840,7 @@ class T5WithLMHeadModel(T5PreTrainedModel): Examples:: tokenizer = T5Tokenizer.from_pretrained('t5-small') - model = T5WithLMHeadModel.from_pretrained('t5-small') + model = T5ForConditionalGeneration.from_pretrained('t5-small') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 outputs = model(input_ids=input_ids, lm_labels=input_ids) loss, prediction_scores = outputs[:2] @@ -839,11 +854,11 @@ class T5WithLMHeadModel(T5PreTrainedModel): self.shared = nn.Embedding(config.vocab_size, config.d_model) encoder_config = copy.deepcopy(config) - self.encoder = T5Stack(encoder_config) + self.encoder = T5Stack(encoder_config, self.shared) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - self.decoder = T5Stack(decoder_config) + self.decoder = T5Stack(decoder_config, self.shared) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -854,50 +869,46 @@ class T5WithLMHeadModel(T5PreTrainedModel): def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) def get_output_embeddings(self): return self.lm_head - def forward(self, **kwargs): - # keyword arguments come in 3 flavors: encoder-specific (prefixed by - # `encoder_`), decoder-specific (prefixed by `decoder_`) and those - # that apply to the model as whole. - # We let the specific kwargs override the common ones in case of conflict. + def get_encoder(self): + return self.encoder - lm_labels = kwargs.pop("decoder_lm_labels", None) - - kwargs_common = dict( - (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") - ) - kwargs_encoder = kwargs_common.copy() - kwargs_decoder = kwargs_common.copy() - kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) - kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_outputs=None, + decoder_input_ids=None, + decoder_attention_mask=None, + lm_labels=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + head_mask=None, + ): # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) - if encoder_hidden_states is None: + if encoder_outputs is None: # Convert encoder inputs in embeddings if needed - hidden_states = kwargs_encoder.pop("inputs_embeds", None) - if hidden_states is None: - encoder_inputs_ids = kwargs_encoder.pop("input_ids") - hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings + encoder_outputs = self.encoder( + input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask + ) - encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0] - else: - encoder_outputs = () + hidden_states = encoder_outputs[0] # Decode - # Convert decoder inputs in embeddings if needed - hidden_states = kwargs_decoder.pop("inputs_embeds", None) - if hidden_states is None: - decoder_inputs_ids = kwargs_decoder.pop("input_ids") - hidden_states = self.shared(decoder_inputs_ids) - - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) - decoder_outputs = self.decoder(hidden_states, **kwargs_decoder) 
+ decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=head_mask, + ) sequence_output = decoder_outputs[0] # Rescale output before projecting on vocab @@ -916,3 +927,22 @@ class T5WithLMHeadModel(T5PreTrainedModel): ) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 return decoder_outputs + encoder_outputs + + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + # first step + if type(past) is tuple: + encoder_outputs = past + else: + encoder_outputs = (past,) + + return { + "decoder_input_ids": input_ids, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + } + + def _reorder_cache(self, past, beam_idx): + # past does not have to be re-ordered for T5. + return past diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index dd661006d..8ea0e6d8a 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -66,7 +66,7 @@ from .modeling_tf_roberta import ( TFRobertaForTokenClassification, TFRobertaModel, ) -from .modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP, TFT5Model, TFT5WithLMHeadModel +from .modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP, TFT5ForConditionalGeneration, TFT5Model from .modeling_tf_transfo_xl import ( TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TFTransfoXLLMHeadModel, @@ -128,7 +128,7 @@ TF_MODEL_MAPPING = OrderedDict( TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( [ - (T5Config, TFT5WithLMHeadModel), + (T5Config, TFT5ForConditionalGeneration), (DistilBertConfig, TFDistilBertForMaskedLM), (AlbertConfig, TFAlbertForMaskedLM), (RobertaConfig, TFRobertaForMaskedLM), @@ -144,7 +144,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ - (T5Config, TFT5WithLMHeadModel), + (T5Config, TFT5ForConditionalGeneration), (DistilBertConfig, TFDistilBertForMaskedLM), (AlbertConfig, TFAlbertForMaskedLM), (RobertaConfig, TFRobertaForMaskedLM), @@ -507,7 +507,7 @@ class TFAutoModelWithLMHead(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): - - contains `t5`: TFT5WithLMHeadModel (T5 model) + - contains `t5`: TFT5ForConditionalGeneration (T5 model) - contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model) - contains `roberta`: TFRobertaForMaskedLM (RoBERTa model) - contains `bert`: TFBertForMaskedLM (Bert model) @@ -571,7 +571,7 @@ class TFAutoModelWithLMHead(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): - - contains `t5`: TFT5WithLMHeadModel (T5 model) + - contains `t5`: TFT5ForConditionalGeneration (T5 model) - contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model) - contains `roberta`: TFRobertaForMaskedLM (RoBERTa model) - contains `bert`: TFBertForMaskedLM (Bert model) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 81290326c..d8012068a 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ 
-160,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a if name not in pt_state_dict: if allow_missing_keys: continue + raise AttributeError("{} not found in PyTorch model".format(name)) array = pt_state_dict[name].numpy() @@ -288,6 +289,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F if allow_missing_keys: missing_keys_pt.append(pt_weight_name) continue + raise AttributeError("{} not found in TF 2.0 model".format(pt_weight_name)) array, transpose = tf_weights_map[pt_weight_name] diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index db62e784b..c17b2b678 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -355,16 +355,49 @@ class TFT5Block(tf.keras.layers.Layer): return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) +class _NoLayerEmbedTokens(object): + """ + this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' + class to avoid problem with weight restoring. Also it makes sure that the layer is + called from the correct scope to avoid problem with saving/storing the correct weights + """ + + def __init__(self, layer, abs_scope_name=None): + self._layer = layer + self._abs_scope_name = abs_scope_name + + def call(self, inputs, mode="embedding"): + if self._abs_scope_name is None: + return self._layer.call(inputs, mode) + + # if an abs scope name is given to the embedding variable, call variable from absolute scope + with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name: + with tf.name_scope(abs_scope_name.original_name_scope): + return self._layer.call(inputs, mode) + + def __call__(self, inputs, mode="embedding"): + if self._abs_scope_name is None: + return self._layer(inputs, mode) + + # if an abs scope name is given to the embedding variable, call variable from absolute scope + with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name: + with tf.name_scope(abs_scope_name.original_name_scope): + return self._layer(inputs, mode) + + #################################################### # The full model without a specific pretrained or finetuning head is # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer" #################################################### class TFT5MainLayer(tf.keras.layers.Layer): - def __init__(self, config, **kwargs): + def __init__(self, config, embed_tokens=None, **kwargs): super().__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states + + self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder + self.config = config self.num_hidden_layers = config.num_layers @@ -375,6 +408,15 @@ class TFT5MainLayer(tf.keras.layers.Layer): self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + def _resize_token_embeddings(self, new_num_tokens): raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models @@ -383,15 +425,31 @@ class TFT5MainLayer(tf.keras.layers.Layer): def call( self, - 
hidden_states, + input_ids, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + inputs_embeds=None, head_mask=None, training=False, ): - batch_size, seq_length = shape_list(hidden_states)[:2] + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, (-1, input_shape[-1])) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + if attention_mask is None: attention_mask = tf.fill((batch_size, seq_length), 1) if self.is_decoder and encoder_attention_mask is None: @@ -465,6 +523,8 @@ class TFT5MainLayer(tf.keras.layers.Layer): all_attentions = () position_bias = None encoder_decoder_position_bias = None + + hidden_states = inputs_embeds for i, layer_module in enumerate(self.block): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -527,7 +587,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel): input_mask = tf.constant(DUMMY_MASK) dummy_inputs = { "decoder_input_ids": input_ids, - "encoder_input_ids": input_ids, + "input_ids": input_ids, "decoder_attention_mask": input_mask, } return dummy_inputs @@ -636,12 +696,18 @@ class TFT5Model(TFT5PreTrainedModel): super().__init__(config, *inputs, **kwargs) self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared") + # retrieve correct absolute scope for embed token wrapper + with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: + pass + + embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name) + encoder_config = copy.deepcopy(config) - self.encoder = TFT5MainLayer(encoder_config, name="encoder") + self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder") decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - self.decoder = TFT5MainLayer(decoder_config, name="decoder") + self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder") def get_input_embeddings(self): return self.shared @@ -650,54 +716,45 @@ class TFT5Model(TFT5PreTrainedModel): return self.shared def call(self, decoder_input_ids, **kwargs): - # We allow two types of multi-inputs: - # - traditional keyword arguments in the call method - # - all the arguments provided as a dict in the first positional argument of call - # The last option is useful to use the tf.keras fit() method. 
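Per its docstring, `_NoLayerEmbedTokens` above is there because the shared `TFSharedEmbeddings` layer is now handed to both `TFT5MainLayer` stacks: wrapping it in a plain Python object keeps Keras from registering (and trying to restore) the same weights under two layer paths, while the `variable_scope`/`name_scope` dance keeps the variables named under `shared`. A framework-free sketch of the delegation idea (stand-in classes, not the real TF objects):

```python
class FakeSharedEmbedding:
    """Stand-in for TFSharedEmbeddings: one object, two modes."""

    def __call__(self, inputs, mode="embedding"):
        return (mode, inputs)                      # real layer returns embeddings or logits


class NoLayerWrapper:
    """Stand-in for _NoLayerEmbedTokens: not a Keras layer, it only forwards calls."""

    def __init__(self, layer):
        self._layer = layer

    def __call__(self, inputs, mode="embedding"):
        return self._layer(inputs, mode)


shared = FakeSharedEmbedding()
wrapper = NoLayerWrapper(shared)
encoder_embed_tokens = wrapper                     # handed to the encoder stack
decoder_embed_tokens = wrapper                     # and to the decoder stack
assert encoder_embed_tokens._layer is decoder_embed_tokens._layer   # single weight owner
```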
if isinstance(decoder_input_ids, dict): kwargs.update(decoder_input_ids) else: kwargs["decoder_input_ids"] = decoder_input_ids - kwargs_common = dict( - (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") - ) - kwargs_encoder = kwargs_common.copy() - kwargs_decoder = kwargs_common.copy() - kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) - kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) + # retrieve arguments + input_ids = kwargs.get("input_ids", None) + decoder_input_ids = kwargs.get("decoder_input_ids", None) + attention_mask = kwargs.get("attention_mask", None) + encoder_outputs = kwargs.get("encoder_outputs", None) + decoder_attention_mask = kwargs.get("decoder_attention_mask", None) + inputs_embeds = kwargs.get("inputs_embeds", None) + decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None) + head_mask = kwargs.get("head_mask", None) # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) - if encoder_hidden_states is None: - # Convert encoder inputs in embeddings if needed - hidden_states = kwargs_encoder.pop("inputs_embeds", None) - if hidden_states is None: - encoder_inputs_ids = kwargs_encoder.pop("input_ids") - hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask + ) - encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0] - else: - encoder_outputs = () + hidden_states = encoder_outputs[0] # Decode - # Convert decoder inputs in embeddings if needed - hidden_states = kwargs_decoder.pop("inputs_embeds", None) - if hidden_states is None: - decoder_inputs_ids = kwargs_decoder.pop("input_ids") - hidden_states = self.shared(decoder_inputs_ids) - - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) - decoder_outputs = self.decoder(hidden_states, **kwargs_decoder) + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=head_mask, + ) return decoder_outputs + encoder_outputs @add_start_docstrings("""T5 Model with a `language modeling` head on top. 
""", T5_START_DOCSTRING, T5_INPUTS_DOCSTRING) -class TFT5WithLMHeadModel(TFT5PreTrainedModel): +class TFT5ForConditionalGeneration(TFT5PreTrainedModel): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` @@ -713,10 +770,10 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel): Examples:: import tensorflow as tf - from transformers import T5Tokenizer, TFT5WithLMHeadModel + from transformers import T5Tokenizer, TFT5ForConditionalGeneration tokenizer = T5Tokenizer.from_pretrained('t5-small') - model = TFT5WithLMHeadModel.from_pretrained('t5-small') + model = TFT5ForConditionalGeneration.from_pretrained('t5-small') input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 outputs = model(input_ids=input_ids) prediction_scores = outputs[0] @@ -729,12 +786,18 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel): self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared") + # retrieve correct absolute scope for embed token wrapper + with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: + pass + + embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name) + encoder_config = copy.deepcopy(config) - self.encoder = TFT5MainLayer(encoder_config, name="encoder") + self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder") decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True - self.decoder = TFT5MainLayer(decoder_config, name="decoder") + self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder") def get_input_embeddings(self): return self.shared @@ -742,52 +805,67 @@ class TFT5WithLMHeadModel(TFT5PreTrainedModel): def get_output_embeddings(self): return self.shared + def get_encoder(self): + return self.encoder + def call(self, decoder_input_ids, **kwargs): - # We allow two types of multi-inputs: - # - traditional keyword arguments in the call method - # - all the arguments provided as a dict in the first positional argument of call - # The last option is useful to use the tf.keras fit() method. 
if isinstance(decoder_input_ids, dict): kwargs.update(decoder_input_ids) else: kwargs["decoder_input_ids"] = decoder_input_ids - kwargs_common = dict( - (k, v) for k, v in kwargs.items() if not k.startswith("encoder_") and not k.startswith("decoder_") - ) - kwargs_encoder = kwargs_common.copy() - kwargs_decoder = kwargs_common.copy() - kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_"))) - kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_"))) + # retrieve arguments + input_ids = kwargs.get("input_ids", None) + decoder_input_ids = kwargs.get("decoder_input_ids", None) + attention_mask = kwargs.get("attention_mask", None) + encoder_outputs = kwargs.get("encoder_outputs", None) + decoder_attention_mask = kwargs.get("decoder_attention_mask", None) + inputs_embeds = kwargs.get("inputs_embeds", None) + decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None) + head_mask = kwargs.get("head_mask", None) # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) - if encoder_hidden_states is None: + if encoder_outputs is None: # Convert encoder inputs in embeddings if needed - hidden_states = kwargs_encoder.pop("inputs_embeds", None) - if hidden_states is None: - encoder_inputs_ids = kwargs_encoder.pop("input_ids") - hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings + encoder_outputs = self.encoder( + input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask + ) - encoder_outputs = self.encoder(hidden_states, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0] - else: - encoder_outputs = () + hidden_states = encoder_outputs[0] # Decode - # Convert decoder inputs in embeddings if needed - hidden_states = kwargs_decoder.pop("inputs_embeds", None) - if hidden_states is None: - decoder_inputs_ids = kwargs_decoder.pop("input_ids") - hidden_states = self.shared(decoder_inputs_ids) - - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states - kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) - decoder_outputs = self.decoder(hidden_states, **kwargs_decoder) + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=head_mask, + ) sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5) - lm_logits = self.shared(sequence_output, mode="linear") + embed_tokens = self.get_output_embeddings() + lm_logits = embed_tokens(sequence_output, mode="linear") decoder_outputs = (lm_logits,) + decoder_outputs[1:] return decoder_outputs + encoder_outputs + + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): + assert past is not None, "past has to be defined for encoder_outputs" + + # first step + if type(past) is tuple: + encoder_outputs = past + else: + encoder_outputs = (past,) + + return { + "inputs": input_ids, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + } + + def _reorder_cache(self, past, beam_idx): + # past does not have to be re-ordered for T5. 
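In the rewritten call above, `TFT5ForConditionalGeneration` projects the decoder output back to the vocabulary by invoking the shared embedding in `"linear"` mode, i.e. the embedding matrix is reused (transposed) as the output projection. A small PyTorch-flavoured sketch of that weight-tying idea, with `nn.Embedding` standing in for `TFSharedEmbeddings` and illustrative shapes only:

```python
import torch
import torch.nn as nn

vocab_size, d_model = 100, 16
shared = nn.Embedding(vocab_size, d_model)            # stand-in for TFSharedEmbeddings

input_ids = torch.randint(0, vocab_size, (2, 5))
inputs_embeds = shared(input_ids)                      # "embedding" mode: ids -> vectors

sequence_output = torch.randn(2, 5, d_model)           # stand-in for decoder hidden states
lm_logits = sequence_output @ shared.weight.t()        # "linear" mode: vectors -> vocab logits
assert lm_logits.shape == (2, 5, vocab_size)
```

The PyTorch `T5ForConditionalGeneration` in this patch keeps a separate `lm_head` linear layer instead of tying the output projection to `shared`.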
+ return past diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 774085b8a..a9767ccfa 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -474,6 +474,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): no_repeat_ngram_size=None, num_return_sequences=None, attention_mask=None, + decoder_start_token_id=None, ): r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling and beam-search. @@ -586,7 +587,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if self.get_output_embeddings() is None: raise AttributeError( "You tried to generate sequences with a model that does not have a LM Head." - "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5WithLMHeadModel`, `TFTransfoXLLMHeadModel`, `TFXLMWithLMHeadModel`)" + "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)" ) max_length = max_length if max_length is not None else self.config.max_length @@ -608,6 +609,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) + decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id if input_ids is not None: batch_size = shape_list(input_ids)[0] # overriden by the input batch_size @@ -634,6 +636,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): assert (eos_token_ids is None) or ( isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids) ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." + assert ( + decoder_start_token_id is not None or self.config.is_encoder_decoder is False + ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" assert length_penalty > 0, "`length_penalty` should be strictely positive." 
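Both `generate()` implementations now accept `decoder_start_token_id` and fall back to `bos_token_id` when it is not given, and encoder-decoder models refuse to run without one; T5's new `pad_token_id`/`eos_token_ids` config defaults give generation sensible start, padding and stopping tokens. A tiny sketch of the fallback (hypothetical helper, not a library function):

```python
def resolve_decoder_start_token_id(decoder_start_token_id, bos_token_id, is_encoder_decoder):
    """Mirror of the fallback added above: an explicit value wins, otherwise bos_token_id."""
    token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
    if is_encoder_decoder and token_id is None:
        raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder models")
    return token_id


assert resolve_decoder_start_token_id(None, bos_token_id=0, is_encoder_decoder=True) == 0
assert resolve_decoder_start_token_id(5, bos_token_id=0, is_encoder_decoder=True) == 5
```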
assert ( isinstance(num_return_sequences, int) and num_return_sequences > 0 @@ -703,6 +708,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): attention_mask, (effective_batch_size * num_beams, input_ids_len) ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + if self.config.is_encoder_decoder: + + assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id" + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) + assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) + + # get encoder and store encoder outputs + encoder = self.get_encoder() + + encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + + # create empty decoder_input_ids + input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id + cur_len = 1 + + else: + encoder_outputs = None + cur_len = shape_list(input_ids)[-1] + if num_beams > 1: output = self._generate_beam_search( input_ids, @@ -716,13 +740,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_ids=eos_token_ids, + decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, num_return_sequences=num_return_sequences, length_penalty=length_penalty, num_beams=num_beams, vocab_size=vocab_size, + encoder_outputs=encoder_outputs, attention_mask=attention_mask, ) else: @@ -737,10 +764,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_ids=eos_token_ids, + decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, vocab_size=vocab_size, + encoder_outputs=encoder_outputs, attention_mask=attention_mask, ) @@ -758,10 +788,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p, repetition_penalty, no_repeat_ngram_size, + bos_token_id, pad_token_id, eos_token_ids, + decoder_start_token_id, batch_size, vocab_size, + encoder_outputs, attention_mask, ): """ Generate sequences for each example without beam search (num_beams == 1). 
@@ -772,7 +805,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): unfinished_sents = tf.ones_like(input_ids[:, 0]) sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length - past = None + past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) @@ -859,6 +892,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if tf.math.reduce_max(unfinished_sents) == 0: break + # extend attention_mask for new generated input if only decoder + if self.config.is_encoder_decoder is False: + attention_mask = tf.concat( + [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1 + ) + cur_len = cur_len + 1 # if there are different sentences lengths in the batch, some batches have to be padded @@ -896,13 +935,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p, repetition_penalty, no_repeat_ngram_size, + bos_token_id, pad_token_id, eos_token_ids, + decoder_start_token_id, batch_size, num_return_sequences, length_penalty, num_beams, vocab_size, + encoder_outputs, attention_mask, ): """ Generate sequences for each example with beam search. @@ -923,8 +965,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32) beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,)) + # cache compute states - past = None + past = encoder_outputs # done sentences done = [False for _ in range(batch_size)] @@ -1088,9 +1131,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx]) input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1) # re-order internal states - if past: + if past is not None: past = self._reorder_cache(past, beam_idx) + if self.config.is_encoder_decoder is False: + attention_mask = tf.concat( + [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1 + ) + # update current length cur_len = cur_len + 1 diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 80abcdbc9..97bee1809 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -806,10 +806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) - # TODO: think about how to make this cleaner - decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id - ) + decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id if input_ids is not None: batch_size = input_ids.shape[0] # overriden by the input batch_size @@ -912,20 +909,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): attention_mask = attention_mask.contiguous().view( effective_batch_size * num_beams, input_ids_len ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + if self.config.is_encoder_decoder: assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id" - # encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs - encoder_inputs = input_ids + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) + assert 
callable(self.get_encoder), "{} should be a method".format(self.get_encoder) + + # get encoder and store encoder outputs + encoder = self.get_encoder() + + encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + + # create empty decoder_input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), - decoder_start_token_id, # TODO: see whether this is the best result + decoder_start_token_id, dtype=torch.long, device=next(self.parameters()).device, ) cur_len = 1 - else: - encoder_inputs = None + encoder_outputs = None cur_len = input_ids.shape[-1] if num_beams > 1: @@ -944,12 +948,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_ids=eos_token_ids, + decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, num_return_sequences=num_return_sequences, length_penalty=length_penalty, num_beams=num_beams, vocab_size=vocab_size, - encoder_inputs=encoder_inputs, + encoder_outputs=encoder_outputs, attention_mask=attention_mask, ) else: @@ -964,10 +969,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_ids=eos_token_ids, + decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, - encoder_inputs=encoder_inputs, + encoder_outputs=encoder_outputs, attention_mask=attention_mask, ) @@ -985,10 +992,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): top_p, repetition_penalty, no_repeat_ngram_size, + bos_token_id, pad_token_id, eos_token_ids, + decoder_start_token_id, batch_size, - encoder_inputs, + encoder_outputs, attention_mask, ): """ Generate sequences for each example without beam search (num_beams == 1). @@ -998,11 +1007,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): unfinished_sents = input_ids.new(batch_size).fill_(1) sent_lengths = input_ids.new(batch_size).fill_(max_length) - past = None + past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models + while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask - ) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] @@ -1099,12 +1107,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): bos_token_id, pad_token_id, eos_token_ids, + decoder_start_token_id, batch_size, num_return_sequences, length_penalty, num_beams, vocab_size, - encoder_inputs, + encoder_outputs, attention_mask, ): """ Generate sequences for each example with beam search. 
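In both beam-search implementations `past` now starts out as the encoder outputs, and re-ordering only happens when `past is not None`. Bart's `_reorder_cache` re-indexes its per-beam decoder caches with `beam_idx`; T5 has no decoder cache in this patch, so its `past` only holds the encoder states, which after the initial expansion are identical copies for every beam of a batch element, and `_reorder_cache` can safely be a no-op. A small sketch of why skipping the re-ordering is harmless (illustrative shapes only):

```python
import torch

num_beams, src_len, d_model = 3, 7, 16

# after generate() expands the inputs, every beam of a batch element carries an
# identical copy of that element's encoder hidden states
encoder_hidden = torch.randn(1, src_len, d_model).repeat(num_beams, 1, 1)

beam_idx = torch.tensor([2, 0, 0])             # beams selected at some decoding step
assert torch.equal(encoder_hidden, encoder_hidden[beam_idx])   # re-ordering changes nothing
```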
@@ -1125,15 +1134,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) # cache compute states - past = None + past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models # done sentences done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask - ) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) @@ -1152,8 +1159,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) if self.config.is_encoder_decoder and do_sample is False: - # TODO: maybe give better naming - scores = self.prepare_scores_for_generation(scores, cur_len, max_length) + # TODO (PVP) still a bit hacky here - there might be a better solutino + scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length) # set eos token prob to zero if min_length is not reached if eos_token_ids is not None and cur_len < min_length: @@ -1278,7 +1285,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) # re-order internal states - if past: + if past is not None: past = self._reorder_cache(past, beam_idx) # extend attention_mask for new generated input if only decoder @@ -1345,8 +1352,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): assert (len(hypo) == max_length for hypo in best) decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) - if self.config.is_encoder_decoder: - return decoded[:, 1:] return decoded # force one of token_ids to be generated by setting prob of all other tokens to 0. 
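The `modeling_utils.py` changes above define the whole encoder-decoder generation flow: `get_encoder()` runs the encoder exactly once, the result rides along in `past`, the decoder starts from a single `decoder_start_token_id`, and `prepare_inputs_for_generation()` re-packages `past` as `encoder_outputs` at every step. A minimal greedy-decoding sketch of that flow using the names introduced in this patch (a tiny random model so nothing is downloaded; the real `generate()` adds sampling, beam search, penalties and length handling, and T5 here re-decodes the whole prefix each step since it has no decoder cache yet):

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration

# tiny random model: enough to exercise the mechanics without downloading weights
config = T5Config(vocab_size=100, d_model=16, d_ff=37, num_layers=2, num_heads=4, d_kv=4)
model = T5ForConditionalGeneration(config).eval()

input_ids = torch.randint(0, 100, (1, 7))           # stand-in for tokenized source text
attention_mask = torch.ones_like(input_ids)

# 1. encode once and stash the result in `past`
past = model.get_encoder()(input_ids, attention_mask=attention_mask)

# 2. the decoder starts from a single decoder_start_token_id (T5 conventionally uses pad)
decoder_input_ids = torch.full((1, 1), config.pad_token_id, dtype=torch.long)

# 3. greedy loop: prepare_inputs_for_generation() turns `past` back into encoder_outputs
for _ in range(10):
    model_inputs = model.prepare_inputs_for_generation(
        decoder_input_ids, past=past, attention_mask=attention_mask
    )
    with torch.no_grad():
        lm_logits = model(**model_inputs)[0]
    next_token = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
    if next_token.item() in config.eos_token_ids:
        break

print(decoder_input_ids)
```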
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index f08028c8d..4e26ee68c 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -82,7 +82,7 @@ class ModelTester: dropout=self.hidden_dropout_prob, attention_dropout=self.attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, - eos_token_ids=[2], + eos_token_ids=self.eos_token_ids, bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, ) @@ -234,12 +234,10 @@ class BartHeadTests(unittest.TestCase): def test_lm_forward(self): config, input_ids, batch_size = self._get_config_and_data(output_past=False) - decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device) + lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device) lm_model = BartForConditionalGeneration(config) lm_model.to(torch_device) - loss, logits, enc_features = lm_model( - input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids - ) + loss, logits, enc_features = lm_model(input_ids=input_ids, lm_labels=lm_labels, decoder_input_ids=input_ids) expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) self.assertEqual(logits.shape, expected_shape) self.assertIsInstance(loss.item(), float) @@ -292,7 +290,7 @@ class BartHeadTests(unittest.TestCase): no_repeat_ngram_size=3, max_length=max_length, ) - self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1)) + self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length)) # TODO(SS): uneven length batches, empty inputs def test_shift_tokens_right(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 23dee7947..a0d0fe402 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -147,7 +147,7 @@ class ModelTesterMixin: 4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions ) decoder_attention_idx = 1 - if "lm_labels" in inputs_dict or "decoder_lm_labels" in inputs_dict: # loss will come first + if "lm_labels" in inputs_dict: # loss will come first correct_outlen += 1 # compute loss decoder_attention_idx += 1 self.assertEqual(out_len, correct_outlen) @@ -601,9 +601,9 @@ class ModelTesterMixin: input_ids = inputs_dict["input_ids"] del inputs_dict["input_ids"] else: - encoder_input_ids = inputs_dict["encoder_input_ids"] + encoder_input_ids = inputs_dict["input_ids"] decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids) - del inputs_dict["encoder_input_ids"] + del inputs_dict["input_ids"] inputs_dict.pop("decoder_input_ids", None) for model_class in self.all_model_classes: @@ -615,7 +615,7 @@ class ModelTesterMixin: if not self.is_encoder_decoder: inputs_dict["inputs_embeds"] = wte(input_ids) else: - inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids) + inputs_dict["inputs_embeds"] = wte(encoder_input_ids) inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids) with torch.no_grad(): @@ -624,9 +624,7 @@ class ModelTesterMixin: def test_lm_head_model_random_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - input_ids = inputs_dict.get( - "input_ids", None - ) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed. 
+        input_ids = inputs_dict.get("input_ids")
 
         if self.is_encoder_decoder:
             config.output_past = True  # needed for Bart TODO: might have to update for other encoder-decoder models
diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py
index 1d7738b64..c8f9de3cc 100644
--- a/tests/test_modeling_t5.py
+++ b/tests/test_modeling_t5.py
@@ -24,14 +24,15 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
 
 
 if is_torch_available():
-    from transformers import T5Config, T5Model, T5WithLMHeadModel
+    from transformers import T5Config, T5Model, T5ForConditionalGeneration
     from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
 
 
 @require_torch
 class T5ModelTest(ModelTesterMixin, unittest.TestCase):
 
-    all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()
+    all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
+    all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     test_pruning = False
     test_torchscript = False
     test_resize_embeddings = False
@@ -56,6 +57,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
             relative_attention_num_buckets=8,
             dropout_rate=0.1,
             initializer_factor=0.002,
+            eos_token_ids=[1],
+            pad_token_id=0,
             scope=None,
         ):
             self.parent = parent
@@ -75,20 +78,22 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
             self.dropout_rate = dropout_rate
             self.initializer_factor = initializer_factor
             self.scope = scope
+            self.eos_token_ids = eos_token_ids
+            self.pad_token_id = pad_token_id
 
         def prepare_config_and_inputs(self):
-            encoder_input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
+            input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
             decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
 
-            encoder_attention_mask = None
+            attention_mask = None
             decoder_attention_mask = None
             if self.use_attention_mask:
-                encoder_attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
+                attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
                 decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
 
-            decoder_lm_labels = None
+            lm_labels = None
             if self.use_labels:
-                decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
+                lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
 
             config = T5Config(
                 vocab_size=self.vocab_size,
@@ -101,41 +106,36 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
                 relative_attention_num_buckets=self.relative_attention_num_buckets,
                 dropout_rate=self.dropout_rate,
                 initializer_factor=self.initializer_factor,
+                eos_token_ids=self.eos_token_ids,
+                bos_token_id=self.pad_token_id,
+                pad_token_id=self.pad_token_id,
             )
 
             return (
                 config,
-                encoder_input_ids,
+                input_ids,
                 decoder_input_ids,
-                encoder_attention_mask,
+                attention_mask,
                 decoder_attention_mask,
-                decoder_lm_labels,
+                lm_labels,
             )
 
         def check_loss_output(self, result):
             self.parent.assertListEqual(list(result["loss"].size()), [])
 
         def create_and_check_t5_model(
-            self,
-            config,
-            encoder_input_ids,
-            decoder_input_ids,
-            encoder_attention_mask,
-            decoder_attention_mask,
-            decoder_lm_labels,
+            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
         ):
             model = T5Model(config=config)
             model.to(torch_device)
             model.eval()
             decoder_output, encoder_output = model(
-                encoder_input_ids=encoder_input_ids,
+                input_ids=input_ids,
                 decoder_input_ids=decoder_input_ids,
-                encoder_attention_mask=encoder_attention_mask,
+                attention_mask=attention_mask,
                 decoder_attention_mask=decoder_attention_mask,
             )
-            decoder_output, encoder_output = model(
-                encoder_input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids
-            )
+            decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
 
             result = {
                 "encoder_output": encoder_output,
@@ -149,22 +149,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
             )
 
         def create_and_check_t5_with_lm_head(
-            self,
-            config,
-            encoder_input_ids,
-            decoder_input_ids,
-            encoder_attention_mask,
-            decoder_attention_mask,
-            decoder_lm_labels,
+            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
         ):
-            model = T5WithLMHeadModel(config=config)
+            model = T5ForConditionalGeneration(config=config)
             model.to(torch_device)
             model.eval()
             outputs = model(
-                encoder_input_ids=encoder_input_ids,
+                input_ids=input_ids,
                 decoder_input_ids=decoder_input_ids,
                 decoder_attention_mask=decoder_attention_mask,
-                decoder_lm_labels=decoder_lm_labels,
+                lm_labels=lm_labels,
             )
             loss, prediction_scores, encoder_features = outputs
             self.parent.assertEqual(len(outputs), 3)
@@ -181,17 +175,18 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
 
             config_and_inputs = self.prepare_config_and_inputs()
             (
                 config,
-                encoder_input_ids,
+                input_ids,
                 decoder_input_ids,
-                encoder_attention_mask,
+                attention_mask,
                 decoder_attention_mask,
-                decoder_lm_labels,
+                lm_labels,
             ) = config_and_inputs
+
             inputs_dict = {
-                "encoder_input_ids": encoder_input_ids,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
                 "decoder_input_ids": decoder_input_ids,
                 "decoder_attention_mask": decoder_attention_mask,
-                "encoder_attention_mask": encoder_attention_mask,
             }
             return config, inputs_dict
 
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index dc52472e8..d2d7fd0b4 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -148,10 +148,12 @@ class TFModelTesterMixin:
             pt_model_class = getattr(transformers, pt_model_class_name)
 
             config.output_hidden_states = True
+
             tf_model = model_class(config)
             pt_model = pt_model_class(config)
 
             # Check we can load pt model in tf and vice-versa with model => model functions
+
             tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
             pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
 
@@ -221,7 +223,7 @@ class TFModelTesterMixin:
         if self.is_encoder_decoder:
             input_ids = {
                 "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
-                "encoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"),
+                "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
             }
         else:
             input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
@@ -393,9 +395,9 @@ class TFModelTesterMixin:
             input_ids = inputs_dict["input_ids"]
             del inputs_dict["input_ids"]
         else:
-            encoder_input_ids = inputs_dict["encoder_input_ids"]
+            encoder_input_ids = inputs_dict["input_ids"]
             decoder_input_ids = inputs_dict["decoder_input_ids"]
-            del inputs_dict["encoder_input_ids"]
+            del inputs_dict["input_ids"]
             del inputs_dict["decoder_input_ids"]
 
         for model_class in self.all_model_classes:
@@ -405,7 +407,7 @@
 
             if not self.is_encoder_decoder:
                 inputs_dict["inputs_embeds"] = self._get_embeds(wte, input_ids)
             else:
-                inputs_dict["encoder_inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
+                inputs_dict["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
                 inputs_dict["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
 
             model(inputs_dict)
@@ -413,9 +415,10 @@ class TFModelTesterMixin:
 
     def test_lm_head_model_random_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict.get(
-            "input_ids", None
-        )  # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
+        input_ids = inputs_dict["input_ids"]
+
+        if self.is_encoder_decoder:
+            config.output_past = True  # needed for Bart TODO: might have to update for other encoder-decoder models
 
         for model_class in self.all_generative_model_classes:
             model = model_class(config)
diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py
index d5589eaf1..731de2540 100644
--- a/tests/test_modeling_tf_t5.py
+++ b/tests/test_modeling_tf_t5.py
@@ -24,14 +24,15 @@ from .utils import CACHE_DIR, require_tf, slow
 
 
 if is_tf_available():
-    from transformers.modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel
+    from transformers.modeling_tf_t5 import TFT5Model, TFT5ForConditionalGeneration
 
 
 @require_tf
 class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
 
     is_encoder_decoder = True
-    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()
+    all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
+    all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
 
     class TFT5ModelTester(object):
         def __init__(
@@ -51,6 +52,8 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             relative_attention_num_buckets=8,
             dropout_rate=0.1,
             initializer_factor=0.002,
+            eos_token_ids=[1],
+            pad_token_id=0,
             scope=None,
         ):
             self.parent = parent
@@ -68,6 +71,8 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             self.relative_attention_num_buckets = relative_attention_num_buckets
             self.dropout_rate = dropout_rate
             self.initializer_factor = initializer_factor
+            self.eos_token_ids = eos_token_ids
+            self.pad_token_id = pad_token_id
             self.scope = scope
 
         def prepare_config_and_inputs(self):
@@ -92,6 +97,9 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
                 relative_attention_num_buckets=self.relative_attention_num_buckets,
                 dropout_rate=self.dropout_rate,
                 initializer_factor=self.initializer_factor,
+                eos_token_ids=self.eos_token_ids,
+                bos_token_id=self.pad_token_id,
+                pad_token_id=self.pad_token_id,
             )
 
             return (config, input_ids, input_mask, token_labels)
@@ -99,15 +107,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
             model = TFT5Model(config=config)
             inputs = {
-                "encoder_input_ids": input_ids,
+                "input_ids": input_ids,
                 "decoder_input_ids": input_ids,
                 "decoder_attention_mask": input_mask,
             }
             encoder_output, decoder_output = model(inputs)
 
-            encoder_output, decoder_output = model(
-                input_ids, decoder_attention_mask=input_mask, encoder_input_ids=input_ids
-            )
+            encoder_output, decoder_output = model(input_ids, decoder_attention_mask=input_mask, input_ids=input_ids)
 
             result = {
                 "encoder_output": encoder_output.numpy(),
@@ -121,13 +127,15 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
 
         def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
-            model = TFT5WithLMHeadModel(config=config)
-            inputs = {
-                "encoder_input_ids": input_ids,
+            model = TFT5ForConditionalGeneration(config=config)
+            inputs_dict = {
+                "input_ids": input_ids,
                 "decoder_input_ids": input_ids,
                 "decoder_attention_mask": input_mask,
             }
-            prediction_scores, decoder_output = model(inputs)
+
+            prediction_scores, decoder_output = model(inputs_dict)
+
             result = {
                 "prediction_scores": prediction_scores.numpy(),
             }
@@ -139,7 +147,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids, input_mask, token_labels) = config_and_inputs
             inputs_dict = {
-                "encoder_input_ids": input_ids,
+                "input_ids": input_ids,
                 "decoder_input_ids": input_ids,
                 "decoder_attention_mask": input_mask,
             }
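
For reference, a minimal usage sketch of the renamed T5ForConditionalGeneration and the generate() support this change enables; the "t5-small" checkpoint name, T5Tokenizer, and the translation prompt below are illustrative assumptions, not taken from this patch:

    # Sketch only: assumes the public "t5-small" checkpoint; prompt and decoding settings are illustrative.
    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # T5 is a text-to-text model; the task is selected with a textual prefix.
    input_ids = torch.tensor([tokenizer.encode("translate English to German: The house is wonderful.")])

    # Greedy decoding, capped at 20 tokens, using the generate() method.
    generated = model.generate(input_ids, max_length=20, do_sample=False)
    print(tokenizer.decode(generated[0].tolist()))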