This commit is contained in:
ydshieh 2024-11-29 13:24:19 +01:00
parent 7662fb9b9a
commit 8d9e49a113

View file

@ -1225,8 +1225,27 @@ class GenerationTesterMixin:
"""
results[-1][1]['input_ids'][0][0] = 1; results[-1][1]['input_ids'][0][1] = 92; inputs_dict['input_ids'] = torch.tensor(results[-1][1]['input_ids']); failed, o, output_greedy, output_assisted = foo(self, model2, inputs_dict, generation_kwargs); print(failed)
results[-1][1]['input_ids'][0][0] = 1; results[-1][1]['input_ids'][0][1] = 92; inputs_dict['input_ids'] = torch.tensor(results[-1][1]['input_ids']); results[-1][1]['attention_mask'] = [[1, 1, 1, 1, 1, 1, 1]]; inputs_dict['attention_mask'] = torch.tensor(results[-1][1]['attention_mask']); failed, o, output_greedy, output_assisted = foo(self, model2, inputs_dict, generation_kwargs); print(failed)
inputs_dict['attention_mask'] = torch.tensor([[1, 1, 0, 1, 1, 0, 0, 1]])
inputs_dict['attention_mask'] = torch.tensor([[1, 1, 0, 1, 1, 0, 0, 1, 1]])
inputs_dict['attention_mask'] = torch.tensor([[1, 1, 0, 1, 1, 0, 0, 1]])
failed, o, output_greedy, output_assisted = foo(self, model2, inputs_dict, generation_kwargs)
output_greedy.sequences
output_assisted.sequences
output_greedy.scores[0] - output_assisted.scores[0]
torch.argmax(model(**inputs_dict).logits[:, -1:, :])
model(**inputs_dict).logits[:, 0, :]
# not the same as there is logit processor
output_greedy.scores[0] - model(**inputs_dict).logits[:, 0, :]
model(**inputs_dict).logits - output_greedy.scores[0]
"""
breakpoint()
# assert 1 == 2