diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index d4fc71802..cb6109a11 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1106,6 +1106,42 @@ class GenerationTesterMixin:
         # - assisted_decoding does not support `use_cache = False`
         # - assisted_decoding does not support `batch_size > 1`
 
+        def foo(self, model, inputs_dict, generation_kwargs):
+
+            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
+
+            # test with the same assistant model or randomly init one
+            # in the first case all candidate tokens are accepted, in the second none is accepted
+            # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
+            if assistant_type == "random":
+                assistant_model = model_class(config).to(torch_device).eval()
+            else:
+                assistant_model = model
+            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
+            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
+            generation_kwargs.update({"assistant_model": assistant_model})
+            output_assisted = model.generate(**generation_kwargs, **inputs_dict)
+
+            # The two outputs must match and their shape must be as expected
+
+            failed = "PASS"
+            try:
+                # The two outputs must match and their shape must be as expected
+                self._check_similar_generate_outputs(output_greedy, output_assisted)
+            except:
+                failed = "FAIL"
+
+            o = {k: inputs_dict[k].detach().cpu().tolist() for k in inputs_dict}  # JSON-serializable copy of the inputs
+            import json
+            s = json.dumps(o)
+            with open("test.txt", "a+") as fp:
+                fp.write(failed + f" ({model.__class__.__name__})" + ": " + s + "\n")
+
+            assert 1 == 2  # deliberate failure: debug-only instrumentation, never meant to pass
+
+            for output in (output_greedy, output_assisted):
+                self._check_outputs(output, model.config, use_cache=True)
+
         for model_class in self.all_generative_model_classes:
             if model_class._is_stateful:
                 self.skipTest(reason="Stateful models don't support assisted generation")
@@ -1154,40 +1190,11 @@ class GenerationTesterMixin:
                 "return_dict_in_generate": True,
                 "use_cache": True,
             }
-            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
-
-            # test with the same assistant model or randomly init one
-            # in the first case all candidate tokens are accepted, in the second none is accepted
-            # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
-            if assistant_type == "random":
-                assistant_model = model_class(config).to(torch_device).eval()
-            else:
-                assistant_model = model
-            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
-            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
-            generation_kwargs.update({"assistant_model": assistant_model})
-            output_assisted = model.generate(**generation_kwargs, **inputs_dict)
-
-            # The two outputs must match and their shape must be as expected
-
-            failed = "PASS"
             try:
-                # The two outputs must match and their shape must be as expected
-                self._check_similar_generate_outputs(output_greedy, output_assisted)
+                foo(self, model, inputs_dict, generation_kwargs)
             except:
-                failed = "FAIL"
-
-            o = {k: inputs_dict[k].detach().cpu().tolist() for k in inputs_dict}
-            import json
-            s = json.dumps(o)
-            with open("test.txt", "a+") as fp:
-                fp.write(failed + f" ({model.__class__.__name__})" + ": " + s + "\n")
-
-            assert 1 == 2
-
-            for output in (output_greedy, output_assisted):
-                self._check_outputs(output, model.config, use_cache=True)
+                pass  # foo always raises (see the assert above); swallow so every model_class gets logged
 
     @pytest.mark.generate
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.