mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix
This commit is contained in:
parent
e60bb24d62
commit
27fe21f27d
1 changed files with 37 additions and 29 deletions
|
|
@ -1106,6 +1106,42 @@ class GenerationTesterMixin:
|
|||
# - assisted_decoding does not support `use_cache = False`
|
||||
# - assisted_decoding does not support `batch_size > 1`
|
||||
|
||||
def foo(self, model, inputs_dict, generation_kwargs):
|
||||
|
||||
output_greedy = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# test with the same assistant model or randomly init one
|
||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||
# case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
|
||||
if assistant_type == "random":
|
||||
assistant_model = model_class(config).to(torch_device).eval()
|
||||
else:
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs.update({"assistant_model": assistant_model})
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
failed = "PASS"
|
||||
try:
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||
except:
|
||||
failed = "FAIL"
|
||||
|
||||
o = {k: inputs_dict[k].detach().cpu().tolist() for k in inputs_dict}
|
||||
import json
|
||||
s = json.dumps(o)
|
||||
with open("test.txt", "a+") as fp:
|
||||
fp.write(failed + f" ({model.__class__.__name__})" + ": " + s + "\n")
|
||||
|
||||
assert 1 == 2
|
||||
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, model.config, use_cache=True)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
|
|
@ -1154,40 +1190,12 @@ class GenerationTesterMixin:
|
|||
"return_dict_in_generate": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
output_greedy = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# test with the same assistant model or randomly init one
|
||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||
# case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
|
||||
if assistant_type == "random":
|
||||
assistant_model = model_class(config).to(torch_device).eval()
|
||||
else:
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs.update({"assistant_model": assistant_model})
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
failed = "PASS"
|
||||
try:
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||
foo(self, model, inputs_dict, generation_kwargs)
|
||||
except:
|
||||
failed = "FAIL"
|
||||
|
||||
o = {k: inputs_dict[k].detach().cpu().tolist() for k in inputs_dict}
|
||||
import json
|
||||
s = json.dumps(o)
|
||||
with open("test.txt", "a+") as fp:
|
||||
fp.write(failed + f" ({model.__class__.__name__})" + ": " + s + "\n")
|
||||
|
||||
assert 1 == 2
|
||||
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
|
||||
|
|
|
|||
Loading…
Reference in a new issue