diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bddbb2ddf..0667363c8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1175,6 +1175,14 @@ class GenerationTesterMixin: model = model_class(config).to(torch_device).eval() model.save_pretrained("mymodel") model2 = model_class.from_pretrained("mymodel") + + s1 = model.state_dict() + s2 = model2.state_dict() + diffs = {} + for k in s1: + diff = torch.amax(torch.abs(s1[k] - s2[k])) + diffs[k] = diff.detach().cpu().tolist() + # Sets assisted generation arguments such that: # a) no EOS is generated, to ensure generation doesn't break early # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of @@ -1208,7 +1216,7 @@ class GenerationTesterMixin: results3.append((failed, o, output_greedy, output_assisted)) results4.append(results3[-1][0]) - for _ in range(200): + for _ in range(1): results.append(foo(self, model, inputs_dict, generation_kwargs)) results2.append(results[-1][0]) results3.append(foo(self, model2, inputs_dict, generation_kwargs))