mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix
This commit is contained in:
parent
7662fb9b9a
commit
8d9e49a113
1 changed files with 19 additions and 0 deletions
|
|
@ -1225,8 +1225,27 @@ class GenerationTesterMixin:
|
|||
"""
|
||||
results[-1][1]['input_ids'][0][0] = 1; results[-1][1]['input_ids'][0][1] = 92; inputs_dict['input_ids'] = torch.tensor(results[-1][1]['input_ids']); failed, o, output_greedy, output_assisted = foo(self, model2, inputs_dict, generation_kwargs); print(failed)
|
||||
results[-1][1]['input_ids'][0][0] = 1; results[-1][1]['input_ids'][0][1] = 92; inputs_dict['input_ids'] = torch.tensor(results[-1][1]['input_ids']); results[-1][1]['attention_mask'] = [[1, 1, 1, 1, 1, 1, 1]]; inputs_dict['attention_mask'] = torch.tensor(results[-1][1]['attention_mask']); failed, o, output_greedy, output_assisted = foo(self, model2, inputs_dict, generation_kwargs); print(failed)
|
||||
|
||||
inputs_dict['attention_mask'] = torch.tensor([[1, 1, 0, 1, 1, 0, 0, 1]])
|
||||
inputs_dict['attention_mask'] = torch.tensor([[1, 1, 0, 1, 1, 0, 0, 1, 1]])
|
||||
inputs_dict['attention_mask'] = torch.tensor([[1, 1, 0, 1, 1, 0, 0, 1]])
|
||||
failed, o, output_greedy, output_assisted = foo(self, model2, inputs_dict, generation_kwargs)
|
||||
output_greedy.sequences
|
||||
output_assisted.sequences
|
||||
output_greedy.scores[0] - output_assisted.scores[0]
|
||||
|
||||
torch.argmax(model(**inputs_dict).logits[:, -1:, :])
|
||||
model(**inputs_dict).logits[:, 0, :]
|
||||
|
||||
# not the same as there is logit processor
|
||||
output_greedy.scores[0] - model(**inputs_dict).logits[:, 0, :]
|
||||
|
||||
model(**inputs_dict).logits - output_greedy.scores[0]
|
||||
|
||||
"""
|
||||
|
||||
|
||||
|
||||
breakpoint()
|
||||
# assert 1 == 2
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue