From e3ecbaa4abfeb98807d75fcb3fb7f68d88d41c4d Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 9 Jan 2023 18:12:13 +0100 Subject: [PATCH] Patch-past-refactor (#21050) * small patches, forgot a line * refactor PT * the actual fix --- .../modeling_tf_vision_encoder_decoder.py | 6 ++---- .../modeling_vision_encoder_decoder.py | 4 ++-- .../modeling_{{cookiecutter.lowercase_modelname}}.py | 2 +- tests/pipelines/test_pipelines_image_to_text.py | 1 - 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 16b9d77f9..50564de22 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -729,13 +729,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos ) def prepare_inputs_for_generation( - self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None past_key_values = decoder_inputs.get("past_key_values") - if past_key_values is None: - past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2 input_dict = { "pixel_values": None, # needs to be passed to make Keras.layer.__call__ happy "attention_mask": attention_mask, diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index aa49086fc..e6c7658da 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -649,9 +649,9 @@ class VisionEncoderDecoderModel(PreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) def prepare_inputs_for_generation( - self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None input_dict = { "attention_mask": attention_mask, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index cc2f295a9..a64bd6c8f 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -3333,7 +3333,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) - if past: + if past_key_values: input_ids = input_ids[:, -1:] # first step, decoder_cached_states are empty return { diff --git a/tests/pipelines/test_pipelines_image_to_text.py b/tests/pipelines/test_pipelines_image_to_text.py index 9beb846ee..0e1e805f9 100644 --- a/tests/pipelines/test_pipelines_image_to_text.py +++ b/tests/pipelines/test_pipelines_image_to_text.py @@ -55,7 +55,6 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta ) @require_tf - @unittest.skip("Arthur will fix me!") def test_small_model_tf(self): pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2", framework="tf") image = "./tests/fixtures/tests_samples/COCO/000000039769.png"