From ddfaf1192629f7efadeb8caea93c77ffd98573cb Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 3 Jul 2024 10:43:44 +0100 Subject: [PATCH] Gemma 2: Update slow tests (#31759) gemma 2 slow tests --- src/transformers/pipelines/text_generation.py | 20 ++--- tests/models/gemma2/test_modeling_gemma2.py | 79 ++++++++++++------- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 994a51748..80f59bf42 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -272,12 +272,17 @@ class TextGenerationPipeline(Pipeline): max_length=None, **generate_kwargs, ): + # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults + tokenizer_kwargs = { + "add_special_tokens": add_special_tokens, + "truncation": truncation, + "padding": padding, + "max_length": max_length, + } + tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None} + if isinstance(prompt_text, Chat): - # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults - tokenizer_kwargs = {} - for tokenizer_kwarg_name in ["truncation", "padding", "max_length"]: - if locals()[tokenizer_kwarg_name] is not None: - tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name] + tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats inputs = self.tokenizer.apply_chat_template( prompt_text.messages, add_generation_prompt=True, @@ -286,11 +291,6 @@ class TextGenerationPipeline(Pipeline): **tokenizer_kwargs, ) else: - # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults - tokenizer_kwargs = {} - for tokenizer_kwarg_name in ["add_special_tokens", "truncation", "padding", "max_length"]: - if locals()[tokenizer_kwarg_name] is not None: - tokenizer_kwargs[tokenizer_kwarg_name] = locals()[tokenizer_kwarg_name] inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs) inputs["prompt_text"] = prompt_text diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 870265f94..20b8ea3ec 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -16,7 +16,7 @@ import unittest -from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available +from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline from transformers.testing_utils import ( require_read_token, require_torch, @@ -102,41 +102,62 @@ class Gemma2IntegrationTest(unittest.TestCase): cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] @require_read_token - def test_model_2b_bf16(self): + def test_model_9b_bf16(self): model_id = "google/gemma-2-9b" EXPECTED_TEXTS = [ - "Hello I am doing a project for a class and I am trying to use the ", - "Hi today. So, I'm going to show you how to do a problem from the textbook. So", + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager" + ).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_read_token + def test_model_9b_fp16(self): + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16, attn_implementation="eager" + ).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_read_token + def test_model_9b_pipeline_bf16(self): + # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR + model_id = "google/gemma-2-9b" + # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hi today I'm going to be talking about the history of the United States. The United States of America", ] model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( torch_device ) - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) - output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True) - self.assertEqual(output_text, EXPECTED_TEXTS) - - @require_read_token - def test_model_2b_fp16(self): - model_id = "google/gemma-2-9b" - EXPECTED_TEXTS = [ - "Hello I am doing a project on the effect of the temperature on the rate of a reaction. I am using a ", - "Hi today I'm going to be talking about the 1000-4000-", - ] - - model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( - torch_device - ) - - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) - - output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_text = tokenizer.batch_decode(output, skip_special_tokens=True) - - self.assertEqual(output_text, EXPECTED_TEXTS) + self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0]) + self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])