mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[generation] consistently add eos tokens (#6982)
Currently beam search returns inconsistent outputs - if hypos have different lengths we get eos, if they are the same - we don't.
This PR makes the output consistent.
Also why not also replace:
```
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
```
with:
```
decoded[i, sent_lengths[i]] = eos_token_id
```
Shouldn't eos always be there? If the data gets truncated, the caller needs to user a larger `max_length`.
Please correct me if my logic is flawed.
This commit is contained in:
parent
d0963486c1
commit
03e363f9ae
1 changed files with 11 additions and 13 deletions
|
|
@ -841,21 +841,19 @@ class GenerationMixin:
|
|||
sent_lengths[effective_batch_idx] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
|
||||
# shorter batches are padded
|
||||
# prepare for adding eos
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
decoded = input_ids.new(output_batch_size, sent_max_len)
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
||||
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
else:
|
||||
# none of the hypotheses have an eos_token
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
|
||||
return decoded
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue