This commit is contained in:
ydshieh 2024-11-29 10:46:20 +01:00
parent 61422b7f26
commit a37afef202

View file

@ -1175,6 +1175,14 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
model.save_pretrained("mymodel")
model2 = model_class.from_pretrained("mymodel")
s1 = model.state_dict()
s2 = model2.state_dict()
diffs = {}
for k in s1:
diff = torch.amax(torch.abs(s1[k] - s2[k]))
diffs[k] = diff.detach().cpu().tolist()
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
@ -1208,7 +1216,7 @@ class GenerationTesterMixin:
results3.append((failed, o, output_greedy, output_assisted))
results4.append(results3[-1][0])
for _ in range(200):
for _ in range(1):
results.append(foo(self, model, inputs_dict, generation_kwargs))
results2.append(results[-1][0])
results3.append(foo(self, model2, inputs_dict, generation_kwargs))