mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix
This commit is contained in:
parent
61422b7f26
commit
a37afef202
1 changed files with 9 additions and 1 deletions
|
|
@ -1175,6 +1175,14 @@ class GenerationTesterMixin:
|
|||
model = model_class(config).to(torch_device).eval()
|
||||
model.save_pretrained("mymodel")
|
||||
model2 = model_class.from_pretrained("mymodel")
|
||||
|
||||
s1 = model.state_dict()
|
||||
s2 = model2.state_dict()
|
||||
diffs = {}
|
||||
for k in s1:
|
||||
diff = torch.amax(torch.abs(s1[k] - s2[k]))
|
||||
diffs[k] = diff.detach().cpu().tolist()
|
||||
|
||||
# Sets assisted generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
|
||||
|
|
@ -1208,7 +1216,7 @@ class GenerationTesterMixin:
|
|||
results3.append((failed, o, output_greedy, output_assisted))
|
||||
results4.append(results3[-1][0])
|
||||
|
||||
for _ in range(200):
|
||||
for _ in range(1):
|
||||
results.append(foo(self, model, inputs_dict, generation_kwargs))
|
||||
results2.append(results[-1][0])
|
||||
results3.append(foo(self, model2, inputs_dict, generation_kwargs))
|
||||
|
|
|
|||
Loading…
Reference in a new issue