mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix test_generated_length_assisted_generation (#34935)
fix test_generated_length_assisted_generation
This commit is contained in:
parent
ec7afad609
commit
42c8ccfd4c
1 changed files with 8 additions and 1 deletions
|
|
@ -3405,7 +3405,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||
assistant_model=assistant,
|
||||
min_new_tokens=10,
|
||||
)
|
||||
self.assertTrue((input_length + 10) <= out.shape[-1] <= 20)
|
||||
self.assertTrue((input_length + 10) <= out.shape[-1])
|
||||
|
||||
out = model.generate(
|
||||
input_ids,
|
||||
assistant_model=assistant,
|
||||
max_new_tokens=7,
|
||||
)
|
||||
self.assertTrue(out.shape[-1] <= (input_length + 7))
|
||||
|
||||
def test_model_kwarg_assisted_decoding_decoder_only(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
|
|
|
|||
Loading…
Reference in a new issue