From c2a26ec8a6cefb51c22f366dca15553cfa5e36fc Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 24 Jun 2020 16:09:17 +0200
Subject: [PATCH] [Use cache] Align logic of `use_cache` with output_attentions and output_hidden_states (#5194)

* fix use cache
* add bart use cache
* fix bart
* finish bart
---
 src/transformers/modeling_bart.py    | 21 +++++++++++++++++++--
 src/transformers/modeling_ctrl.py    |  5 +++--
 src/transformers/modeling_gpt2.py    |  7 ++++---
 src/transformers/modeling_t5.py      | 12 +++++++++---
 src/transformers/modeling_tf_ctrl.py |  4 +++-
 src/transformers/modeling_tf_gpt2.py |  9 ++++++---
 src/transformers/modeling_tf_t5.py   | 14 +++++++++++---
 tests/test_modeling_bart.py          |  1 +
 tests/test_modeling_gpt2.py          |  9 ++++++++-
 tests/test_modeling_t5.py            |  9 ++++++++-
 tests/test_modeling_tf_common.py     |  1 +
 tests/test_modeling_tf_gpt2.py       |  9 ++++++++-
 tests/test_modeling_tf_t5.py         | 10 +++++++++-
 13 files changed, 90 insertions(+), 21 deletions(-)

diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index 42ea80ce1..8f090f5b1 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -815,14 +815,19 @@ class BartModel(PretrainedBartModel):
         encoder_outputs: Optional[Tuple] = None,
         decoder_attention_mask=None,
         decoder_cached_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
+
+        if decoder_input_ids is None:
+            use_cache = False
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         # make masks if user doesn't supply
         if not use_cache:
@@ -915,7 +920,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
         labels=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         **unused,
@@ -968,6 +973,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
             )
             labels = unused.pop("lm_labels")
 
+        if labels is not None:
+            use_cache = False
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -1070,6 +1078,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         labels=None,
         output_attentions=None,
         output_hidden_states=None,
+        use_cache=None,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1106,6 +1115,9 @@ class BartForSequenceClassification(PretrainedBartModel):
             loss, logits = outputs[:2]
 
         """
+        if labels is not None:
+            use_cache = False
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -1114,6 +1126,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
         )
         x = outputs[0] # last hidden state
         eos_mask = input_ids.eq(self.config.eos_token_id)
@@ -1159,6 +1172,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
         end_positions=None,
         output_attentions=None,
         output_hidden_states=None,
+        use_cache=None,
     ):
         r"""
         start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1206,6 +1220,8 @@ class BartForQuestionAnswering(PretrainedBartModel):
             answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
 
         """
+        if start_positions is not None and end_positions is not None:
+            use_cache = False
 
         outputs = self.model(
             input_ids,
@@ -1215,6 +1231,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
         )
         sequence_output = outputs[0]
 
diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py
index d10fe66b0..719b8ccdf 100644
--- a/src/transformers/modeling_ctrl.py
+++ b/src/transformers/modeling_ctrl.py
@@ -335,7 +335,7 @@ class CTRLModel(CTRLPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -374,6 +374,7 @@ class CTRLModel(CTRLPreTrainedModel):
 
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -519,7 +520,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index ed9175dd0..36bd0bc7b 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -379,7 +379,7 @@ class GPT2Model(GPT2PreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -420,6 +420,7 @@ class GPT2Model(GPT2PreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -562,7 +563,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
@@ -671,7 +672,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         mc_token_ids=None,
         labels=None,
         mc_labels=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         **kwargs
diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py
index 4a606fcf5..572b8382b 100644
--- a/src/transformers/modeling_t5.py
+++ b/src/transformers/modeling_t5.py
@@ -659,11 +659,12 @@ class T5Stack(T5PreTrainedModel):
         inputs_embeds=None,
         head_mask=None,
         past_key_value_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
     ):
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -854,6 +855,7 @@ class T5Model(T5PreTrainedModel):
         self.shared = nn.Embedding(config.vocab_size, config.d_model)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = T5Stack(encoder_config, self.shared)
 
         decoder_config = copy.deepcopy(config)
@@ -893,7 +895,7 @@ class T5Model(T5PreTrainedModel):
         decoder_input_ids=None,
         decoder_attention_mask=None,
         decoder_past_key_value_states=None,
-        use_cache=True,
+        use_cache=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         head_mask=None,
@@ -933,6 +935,7 @@ class T5Model(T5PreTrainedModel):
             last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
 
         """
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
 
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
@@ -985,6 +988,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.shared = nn.Embedding(config.vocab_size, config.d_model)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = T5Stack(encoder_config, self.shared)
 
         decoder_config = copy.deepcopy(config)
@@ -1021,7 +1025,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         decoder_input_ids=None,
         decoder_attention_mask=None,
         decoder_past_key_value_states=None,
-        use_cache=True,
+        use_cache=None,
         labels=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
@@ -1086,6 +1090,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             labels = kwargs.pop("lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
 
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed
diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py
index 34b8d1aab..5c93b89eb 100644
--- a/src/transformers/modeling_tf_ctrl.py
+++ b/src/transformers/modeling_tf_ctrl.py
@@ -186,6 +186,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
+        self.use_cache = config.use_cache
 
         self.d_model_size = config.n_embd
         self.num_layers = config.n_layer
@@ -235,7 +236,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
@@ -270,6 +271,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
 
         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache
 
         # If using past key value states, only the last tokens
         # should be given as an input
diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py
index 90f696062..bf83f9d95 100644
--- a/src/transformers/modeling_tf_gpt2.py
+++ b/src/transformers/modeling_tf_gpt2.py
@@ -215,6 +215,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         super().__init__(*inputs, **kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
+        self.use_cache = config.use_cache
+
         self.num_hidden_layers = config.n_layer
         self.vocab_size = config.vocab_size
         self.n_embd = config.n_embd
@@ -254,10 +256,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
-        use_cache=True,
-        training=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        training=False,
     ):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
@@ -288,6 +290,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
 
         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -622,7 +625,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
         mc_token_ids=None,
-        use_cache=True,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py
index 2664dddcc..959ceadd8 100644
--- a/src/transformers/modeling_tf_t5.py
+++ b/src/transformers/modeling_tf_t5.py
@@ -518,6 +518,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
         self.output_attentions = config.output_attentions
+        self.use_cache = config.use_cache
 
         self.embed_tokens = embed_tokens
         self.is_decoder = config.is_decoder
@@ -556,7 +557,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         inputs_embeds=None,
         head_mask=None,
         past_key_value_states=None,
-        use_cache=False,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         training=False,
@@ -586,6 +587,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
 
         output_attentions = output_attentions if output_attentions is not None else self.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.use_cache
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
@@ -874,6 +876,7 @@ class TFT5Model(TFT5PreTrainedModel):
         embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
 
         decoder_config = copy.deepcopy(config)
@@ -952,11 +955,13 @@ class TFT5Model(TFT5PreTrainedModel):
         decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
         decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
         decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
-        use_cache = kwargs.get("use_cache", True)
+        use_cache = kwargs.get("use_cache", None)
         head_mask = kwargs.get("head_mask", None)
         output_attentions = kwargs.get("output_attentions", None)
         output_hidden_states = kwargs.get("output_hidden_states", None)
 
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
@@ -1014,6 +1019,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
         embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
 
         encoder_config = copy.deepcopy(config)
+        encoder_config.use_cache = False
         self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
 
         decoder_config = copy.deepcopy(config)
@@ -1095,13 +1101,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
         encoder_outputs = kwargs.get("encoder_outputs", None)
         decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
         decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
-        use_cache = kwargs.get("use_cache", True)
+        use_cache = kwargs.get("use_cache", None)
         inputs_embeds = kwargs.get("inputs_embeds", None)
         decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
         head_mask = kwargs.get("head_mask", None)
         output_attentions = kwargs.get("output_attentions", None)
         output_hidden_states = kwargs.get("output_hidden_states", None)
 
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 208c896cf..3cafb3a40 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -153,6 +153,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     def test_advanced_inputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.use_cache = False
         inputs_dict["input_ids"][:, -2:] = config.pad_token_id
         decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
             config, inputs_dict["input_ids"]
         )
diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py
index eb4161f6d..2dbda2e85 100644
--- a/tests/test_modeling_gpt2.py
+++ b/tests/test_modeling_gpt2.py
@@ -168,7 +168,14 @@ class GPT2ModelTester:
         model.eval()
 
         # first forward pass
-        output, past = model(input_ids, token_type_ids=token_type_ids)
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
+        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past = outputs
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py
index dd480873c..a5d7a1a4d 100644
--- a/tests/test_modeling_t5.py
+++ b/tests/test_modeling_t5.py
@@ -193,7 +193,14 @@ class T5ModelTester:
         model.eval()
 
         # first forward pass
-        output, past_key_value_states = model(input_ids, use_cache=True)
+        outputs = model(input_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids)
+        outputs_no_past = model(input_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past_key_value_states = outputs
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 99a00e5d3..b38d2db4f 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -126,6 +126,7 @@ class TFModelTesterMixin:
             if "T5" in main_layer_class.__name__:
                 # Take the same values than in TFT5ModelTester for this shared layer
                 shared = TFSharedEmbeddings(99, 32, name="shared")
+                config.use_cache = False
                 main_layer = main_layer_class(config, embed_tokens=shared)
             else:
                 main_layer = main_layer_class(config)
diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py
index 6afd2959f..2f733491a 100644
--- a/tests/test_modeling_tf_gpt2.py
+++ b/tests/test_modeling_tf_gpt2.py
@@ -143,7 +143,14 @@ class TFGPT2ModelTester:
         model = TFGPT2Model(config=config)
 
         # first forward pass
-        output, past = model(input_ids, token_type_ids=token_type_ids)
+        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
+        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
+        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past = outputs
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py
index d19e18d6d..28cf60e46 100644
--- a/tests/test_modeling_tf_t5.py
+++ b/tests/test_modeling_tf_t5.py
@@ -135,7 +135,15 @@ class TFT5ModelTester:
         self.batch_size = 1
 
         # first forward pass
-        _, past_key_value_states = model(input_ids, use_cache=True)
+        outputs = model(input_ids, use_cache=True)
+
+        outputs_use_cache_conf = model(input_ids)
+        outputs_no_past = model(input_ids, use_cache=False)
+
+        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
+        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
+
+        output, past_key_value_states = outputs
 
         # create hypothetical next token and extent to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
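
Reviewer note (illustrative, not part of the patch): every file above applies the same resolution pattern that was already used for output_attentions and output_hidden_states. The forward/call signature defaults use_cache to None, the value falls back to config.use_cache when the caller passes nothing, and it is forced to False where caching cannot help (labels or QA positions are supplied, or there is no decoder input). A minimal standalone sketch of that pattern, using hypothetical SimpleConfig/SimpleModel stand-ins rather than real transformers classes:

    class SimpleConfig:
        """Hypothetical config object; mirrors the three flags touched by the patch."""

        def __init__(self, use_cache=True, output_attentions=False, output_hidden_states=False):
            self.use_cache = use_cache
            self.output_attentions = output_attentions
            self.output_hidden_states = output_hidden_states


    class SimpleModel:
        """Hypothetical model; only shows how the call-time flags are resolved."""

        def __init__(self, config):
            self.config = config

        def forward(self, input_ids, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None):
            # Explicit call arguments win; otherwise fall back to the config defaults.
            use_cache = use_cache if use_cache is not None else self.config.use_cache
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            # With labels (training), cached past key/values are never reused, so disable them.
            if labels is not None:
                use_cache = False
            return use_cache, output_attentions, output_hidden_states


    # Nothing passed at call time -> the config decides; use_cache resolves to True here.
    model = SimpleModel(SimpleConfig())
    assert model.forward(input_ids=[0, 1, 2]) == (True, False, False)
    # Passing labels forces use_cache off even though the config enables it.
    assert model.forward(input_ids=[0, 1, 2], labels=[1, 2, 3]) == (False, False, False)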