Mirror of https://github.com/saymrwulf/transformers.git
Synced 2026-05-14 20:58:08 +00:00
fix tests with main revision and read token (#33560)
* fix tests with main revision and read token
* [run-slow]mamba2
* test previously skipped tests
* [run-slow]mamba2
* skip some tests
* [run-slow]mamba2
* finalize tests
* [run-slow]mamba2
parent 80b774eb29
commit 4f0246e535
1 changed file with 12 additions and 25 deletions
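With the PR-ref pin removed, these tests read the gated mistralai/Mamba-Codestral-7B-v0.1 checkpoint from its default `main` revision, which requires a Hugging Face read token. A minimal sketch of one way to supply it before running the slow tests; the `HF_TOKEN` environment variable here is an assumption, not something this commit sets up:

```python
import os

from huggingface_hub import login

# Assumption: a read token with access to the gated repo is exported as HF_TOKEN.
login(token=os.environ["HF_TOKEN"])
```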
tests/models/mamba2/test_modeling_mamba2.py
@@ -20,7 +20,7 @@ from typing import Dict, List, Tuple
 from parameterized import parameterized

 from transformers import AutoTokenizer, Mamba2Config, is_torch_available
-from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
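`require_read_token` is the testing-utils decorator the new import brings in; it gates a test (or a whole test class, as in a hunk further down) on a configured Hub read token. A self-contained sketch of the pattern, with an illustrative test class that is not part of this diff:

```python
import unittest

from transformers.testing_utils import require_read_token, slow


@slow
@require_read_token  # skipped unless a Hub read token is configured
class GatedRepoSmokeTest(unittest.TestCase):
    def test_placeholder(self):
        # Illustrative body only; real tests would touch the gated checkpoint.
        self.assertTrue(True)
```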
@@ -96,7 +96,7 @@ class Mamba2ModelTester:
         self.tie_word_embeddings = tie_word_embeddings

     def get_large_model_config(self):
-        return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", revision="refs/pr/9")
+        return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")

     def prepare_config_and_inputs(
         self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
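The pattern behind this hunk: `from_pretrained` accepts an optional `revision` argument naming a Hub branch, tag, or PR ref, and omitting it falls back to `main`. A minimal illustration of both forms:

```python
from transformers import Mamba2Config

# Old behavior: pin the config to a pull-request ref on the Hub.
config_pinned = Mamba2Config.from_pretrained(
    "mistralai/Mamba-Codestral-7B-v0.1", revision="refs/pr/9"
)

# New behavior: load from the default main revision.
config_main = Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
```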
@@ -199,34 +199,26 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_tied_weights_keys(self):
         pass

-    @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
-    def test_beam_search_generate_dict_outputs_use_cache(self):
-        pass
-
-    @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
-    def test_beam_sample_generate(self):
+    @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
+    def test_generate_without_input_ids(self):
         pass

     @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
-    def test_generate_without_input_ids(self):
+    def test_generate_from_inputs_embeds_decoder_only(self):
         pass

     @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
     def test_greedy_generate_dict_outputs_use_cache(self):
         pass

-    @unittest.skip(reason="Initialization of mamba2 fails this")
-    def test_save_load_fast_init_from_base(self):
+    @unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
+    def test_beam_search_generate_dict_outputs_use_cache(self):
         pass

     @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
     def test_multi_gpu_data_parallel_forward(self):
         pass

-    @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
-    def test_generate_from_inputs_embeds_decoder_only(self):
-        pass
-
     def test_model_outputs_equivalence(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

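The skip reshuffle above is all stdlib `unittest`; for reference, a tiny self-contained sketch of how a skipped test behaves (names are illustrative):

```python
import unittest


class SkipDemoTest(unittest.TestCase):
    @unittest.skip(reason="illustrative: feature not supported yet")
    def test_never_runs(self):
        raise AssertionError("unreachable; the runner reports a skip instead")

    def test_runs(self):
        self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
    unittest.main()
```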
@@ -292,12 +284,11 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix

 @require_torch
 @slow
+@require_read_token
 class Mamba2IntegrationTest(unittest.TestCase):
     def setUp(self):
         self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            self.model_id, revision="refs/pr/9", from_slow=True, legacy=False
-        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False)
         self.prompt = ("[INST]Write a hello world program in C++.",)

     @parameterized.expand(
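The simplified `setUp` builds the tokenizer from the default revision; `from_slow=True` forces the fast tokenizer to be converted from the slow sentencepiece one, and `legacy=False` opts into the corrected tokenization behavior. A usage sketch, assuming the running environment already has read access to the gated repo:

```python
from transformers import AutoTokenizer

# Assumption: a read token with access to the gated checkpoint is configured.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mamba-Codestral-7B-v0.1", from_slow=True, legacy=False
)
print(tokenizer("[INST]Write a hello world program in C++.[/INST]").input_ids[:8])
```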
@@ -317,7 +308,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
         tokenizer = self.tokenizer
         tokenizer.pad_token_id = tokenizer.eos_token_id

-        model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16)
+        model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16)
         model.to(device)
         input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
             device
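A sketch of the single-prompt path this hunk sets up, end to end; the `max_new_tokens=64` budget and the CUDA device string are assumptions, not values from the test:

```python
import torch

from transformers import AutoTokenizer, Mamba2ForCausalLM

model_id = "mistralai/Mamba-Codestral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

input_ids = tokenizer(
    "[INST]Write a hello world program in C++.[/INST]", return_tensors="pt"
).input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=64)  # assumed token budget
print(tokenizer.decode(output[0], skip_special_tokens=True))
```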
@@ -343,9 +334,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
             "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
         ]

-        model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
-            torch_device
-        )
+        model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
         tokenizer.pad_token_id = tokenizer.eos_token_id
         # batched generation
         tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
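The batched variant pads every prompt to the longest one and reuses the EOS id for padding; continuing the assumed setup from the previous sketch:

```python
# tokenizer and model as in the sketch above; pad with EOS for batching.
tokenizer.pad_token_id = tokenizer.eos_token_id
prompts = [
    "[INST]Write a hello world program in C++.[/INST]",
    "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
]
batch = tokenizer(prompts, return_tensors="pt", padding="longest").to("cuda")
outputs = model.generate(**batch, max_new_tokens=64)  # assumed budget
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```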
@@ -375,9 +364,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
             "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
         ]

-        model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
-            torch_device
-        )
+        model = Mamba2ForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)
         tokenizer.pad_token_id = tokenizer.eos_token_id
         # batched generation
         tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)