This commit is contained in:
ydshieh 2024-11-29 15:55:07 +01:00
parent c40a12e3b8
commit d37c8fb890
2 changed files with 6 additions and 0 deletions

View file

@@ -4268,6 +4268,7 @@ class GenerationMixin:
breakpoint()
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
breakpoint()
if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)
@@ -4304,12 +4305,14 @@ class GenerationMixin:
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
outputs = self(**model_inputs)
breakpoint()
# 2.3. Process the new logits
# .float() is needed to retain precision for later logits manipulations
new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present
new_logits = new_logits.to(input_ids.device)
next_token_logits = new_logits.clone()
breakpoint()
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
@@ -4349,6 +4352,7 @@ class GenerationMixin:
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
# is no match.
breakpoint()
# 4.1. Get the valid continuation, after the matching tokens
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:

View file

@@ -1108,6 +1108,8 @@ class GenerationTesterMixin:
def foo(self, model, inputs_dict, generation_kwargs):
inputs_dict['input_ids'] = torch.tensor([[88, 1, 37, 40, 82, 36, 80]])
generation_kwargs = copy.deepcopy(generation_kwargs)
output_greedy = model.generate(**generation_kwargs, **inputs_dict)