mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
fix
This commit is contained in:
parent
c40a12e3b8
commit
d37c8fb890
2 changed files with 6 additions and 0 deletions
|
|
@ -4268,6 +4268,7 @@ class GenerationMixin:
|
|||
breakpoint()
|
||||
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
||||
breakpoint()
|
||||
|
||||
if candidate_logits is not None:
|
||||
candidate_logits = candidate_logits.to(self.device)
|
||||
|
|
@ -4304,12 +4305,14 @@ class GenerationMixin:
|
|||
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
|
||||
|
||||
outputs = self(**model_inputs)
|
||||
breakpoint()
|
||||
|
||||
# 2.3. Process the new logits
|
||||
# .float() is needed to retain precision for later logits manipulations
|
||||
new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present
|
||||
new_logits = new_logits.to(input_ids.device)
|
||||
next_token_logits = new_logits.clone()
|
||||
breakpoint()
|
||||
if len(logits_processor) > 0:
|
||||
for i in range(candidate_length + 1):
|
||||
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||
|
|
@ -4349,6 +4352,7 @@ class GenerationMixin:
|
|||
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
|
||||
# is no match.
|
||||
|
||||
breakpoint()
|
||||
# 4.1. Get the valid continuation, after the matching tokens
|
||||
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
|
||||
if streamer is not None:
|
||||
|
|
|
|||
|
|
@ -1108,6 +1108,8 @@ class GenerationTesterMixin:
|
|||
|
||||
def foo(self, model, inputs_dict, generation_kwargs):
|
||||
|
||||
inputs_dict['input_ids'] = torch.tensor([[88, 1, 37, 40, 82, 36, 80]])
|
||||
|
||||
generation_kwargs = copy.deepcopy(generation_kwargs)
|
||||
output_greedy = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue