ydshieh 2024-11-29 09:31:02 +01:00
parent e60bb24d62
commit 27fe21f27d


@@ -1106,6 +1106,42 @@ class GenerationTesterMixin:
         # - assisted_decoding does not support `use_cache = False`
         # - assisted_decoding does not support `batch_size > 1`
+        def foo(self, model, inputs_dict, generation_kwargs):
+            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
+
+            # Test with the same assistant model or a randomly initialized one.
+            # In the first case all candidate tokens are accepted; in the second, none are.
+            # The mixed case (some accepted, some rejected) is hard to reproduce, so let's hope this catches most errors :)
+            if assistant_type == "random":
+                assistant_model = model_class(config).to(torch_device).eval()
+            else:
+                assistant_model = model
+            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
+            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
+            generation_kwargs.update({"assistant_model": assistant_model})
+            output_assisted = model.generate(**generation_kwargs, **inputs_dict)
+
+            failed = "PASS"
+            try:
+                # The two outputs must match and their shape must be as expected
+                self._check_similar_generate_outputs(output_greedy, output_assisted)
+            except Exception:
+                failed = "FAIL"
+
+            # Debug instrumentation: append the PASS/FAIL status and the JSON-serialized
+            # inputs to `test.txt`, then fail unconditionally so the dump surfaces in CI.
+            o = {k: inputs_dict[k].detach().cpu().tolist() for k in inputs_dict}
+            import json
+
+            s = json.dumps(o)
+            with open("test.txt", "a+") as fp:
+                fp.write(failed + f" ({model.__class__.__name__})" + ": " + s + "\n")
+            assert 1 == 2  # force a failure so the dump above is always collected
+
+            for output in (output_greedy, output_assisted):
+                self._check_outputs(output, model.config, use_cache=True)
+
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
                 self.skipTest(reason="Stateful models don't support assisted generation")
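
For reference, the equivalence that `foo` checks in the "same assistant" branch (assisted decoding with the model as its own assistant accepts every candidate token, so it must reproduce greedy search exactly) can be exercised outside the test harness. A minimal sketch, not part of this commit; the `gpt2` checkpoint, prompt, and `max_new_tokens` value are illustrative assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Plain greedy decoding, no assistant.
greedy = model.generate(**inputs, do_sample=False, max_new_tokens=20)

# Assisted decoding with the model as its own assistant: every proposed
# candidate token is accepted, so the result must equal greedy search.
assisted = model.generate(
    **inputs, do_sample=False, max_new_tokens=20, assistant_model=model
)

assert torch.equal(greedy, assisted)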
@@ -1154,40 +1190,12 @@ class GenerationTesterMixin:
                 "return_dict_in_generate": True,
                 "use_cache": True,
             }
-            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
-            # test with the same assistant model or randomly init one
-            # in the first case all candidate tokens are accepted, in the second none is accepted
-            # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
-            if assistant_type == "random":
-                assistant_model = model_class(config).to(torch_device).eval()
-            else:
-                assistant_model = model
-            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
-            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
-            generation_kwargs.update({"assistant_model": assistant_model})
-            output_assisted = model.generate(**generation_kwargs, **inputs_dict)
-            # The two outputs must match and their shape must be as expected
-            failed = "PASS"
-            try:
-                # The two outputs must match and their shape must be as expected
-                self._check_similar_generate_outputs(output_greedy, output_assisted)
+            foo(self, model, inputs_dict, generation_kwargs)
-            except:
-                failed = "FAIL"
-            o = {k: inputs_dict[k].detach().cpu().tolist() for k in inputs_dict}
-            import json
-            s = json.dumps(o)
-            with open("test.txt", "a+") as fp:
-                fp.write(failed + f" ({model.__class__.__name__})" + ": " + s + "\n")
-            assert 1 == 2
-            for output in (output_greedy, output_assisted):
-                self._check_outputs(output, model.config, use_cache=True)

     @pytest.mark.generate
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
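
A note on the instrumentation above: each line appended to `test.txt` has the shape `PASS|FAIL (<ModelClass>): <json>`, where the JSON payload is `inputs_dict` with every tensor converted via `.tolist()`. A minimal sketch for replaying such a dump offline (the file path is taken from the code above; reconstructing values with `torch.tensor` is an assumption that holds for integer inputs like `input_ids`):

import json
import torch

with open("test.txt") as fp:
    for line in fp:
        status, rest = line.rstrip("\n").split(" ", 1)
        # The first ": " separates "(<ModelClass>)" from the JSON payload.
        model_name, payload = rest.split(": ", 1)
        inputs = {k: torch.tensor(v) for k, v in json.loads(payload).items()}
        print(status, model_name.strip("()"), {k: t.shape for k, t in inputs.items()})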