mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
skip gptj slow generate tests for now (#13809)
This commit is contained in:
parent
41436d3dfb
commit
8bbb53e20b
1 changed files with 4 additions and 2 deletions
|
|
@ -396,8 +396,9 @@ class GPTJModelTest(unittest.TestCase):
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
|
||||
|
||||
@slow
|
||||
@tooslow
|
||||
def test_batch_generation(self):
|
||||
# Marked as @tooslow due to GPU OOM
|
||||
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16)
|
||||
model.to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", revision="float16")
|
||||
|
|
@ -464,8 +465,9 @@ class GPTJModelTest(unittest.TestCase):
|
|||
|
||||
@require_torch
|
||||
class GPTJModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
@tooslow
|
||||
def test_lm_generate_gptj(self):
|
||||
# Marked as @tooslow due to GPU OOM
|
||||
for checkpointing in [True, False]:
|
||||
model = GPTJForCausalLM.from_pretrained(
|
||||
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
|
||||
|
|
|
|||
Loading…
Reference in a new issue