Fix tests with main revision and read token (#33560)

* Fix tests with main revision and read token

* [run-slow]mamba2

* test previously skipped tests

* [run-slow]mamba2

* skip some tests

* [run-slow]mamba2

* finalize tests

* [run-slow]mamba2
This commit is contained in:
Pablo Montalvo 2024-09-19 17:10:22 +02:00 committed by GitHub
parent 80b774eb29
commit 4f0246e535
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -20,7 +20,7 @@ from typing import Dict, List, Tuple
from parameterized import parameterized
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -96,7 +96,7 @@ class Mamba2ModelTester:
self.tie_word_embeddings = tie_word_embeddings
def get_large_model_config(self):
return Mamba2Config.from_pretrained("revision='refs/pr/9'")
return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
def prepare_config_and_inputs(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
@ -199,34 +199,26 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_tied_weights_keys(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
def test_beam_sample_generate(self):
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_generate_without_input_ids(self):
def test_generate_from_inputs_embeds_decoder_only(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Initialization of mamba2 fails this")
def test_save_load_fast_init_from_base(self):
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_generate_from_inputs_embeds_decoder_only(self):
pass
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -292,12 +284,11 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
@require_torch
@slow
@require_read_token
class Mamba2IntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_id, revision="refs/pr/9", from_slow=True, legacy=False
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
self.prompt = ("[INST]Write a hello world program in C++.",)
@parameterized.expand(
@ -317,7 +308,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
tokenizer = self.tokenizer
tokenizer.pad_token_id = tokenizer.eos_token_id
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16)
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
model.to(device)
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
device
@ -343,9 +334,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
]
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
torch_device
)
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
tokenizer.pad_token_id = tokenizer.eos_token_id
# batched generation
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
@ -375,9 +364,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
]
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
torch_device
)
model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
tokenizer.pad_token_id = tokenizer.eos_token_id
# batched generation
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)