Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
commit 7351a8dbaf: re-add scoring filtering
parent 9b8ee8cea0

2 changed files with 8 additions and 10 deletions
@@ -1084,10 +1084,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token_logits = next_token_logits / temperature
 
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-            # if (
-            #     self.config.is_encoder_decoder and do_sample is False
-            # ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
-            #     scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
+            if self.config.is_encoder_decoder and do_sample is False:
+                # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+                scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
 
             # set eos token prob to zero if min_length is not reached
             if eos_token_ids is not None and cur_len < min_length:
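For context, `prepare_scores_for_generation` is the model-specific hook that the re-enabled branch applies to the log-softmaxed scores when an encoder-decoder model generates without sampling. A minimal sketch of the kind of score filtering such a hook can perform, forcing a single token at chosen steps by pushing all other scores to -inf; the helper name, token ids, and forcing schedule here are illustrative assumptions, not this repository's implementation:

import torch

def force_token(scores: torch.Tensor, token_id: int) -> torch.Tensor:
    # Keep the log-prob of `token_id` and set every other token to -inf,
    # so beam search / greedy decoding can only pick `token_id`.
    forced = torch.full_like(scores, float("-inf"))
    forced[:, token_id] = scores[:, token_id]
    return forced

def prepare_scores_sketch(scores, cur_len, max_length, bos_id=0, eos_id=2):
    # scores: (batch_size * num_beams, vocab_size) log-probabilities.
    # bos_id / eos_id are toy values for this sketch.
    if cur_len == 1:
        scores = force_token(scores, bos_id)  # force BOS right after the decoder seed token
    if cur_len == max_length - 1:
        scores = force_token(scores, eos_id)  # force EOS on the final step
    return scores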
@@ -1279,10 +1278,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert (len(hypo) == max_length for hypo in best)
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
 
-        # if self.config.is_encoder_decoder:
-        #     # do not return first <EOS> token
-        #     return decoded[:, 1:]
-        return decoded
+        if self.config.is_encoder_decoder:
+            # do not return first <EOS> token
+            return decoded[:, 1:]
+        # return decoded
 
     # force one of token_ids to be generated by setting prob of all other tokens to 0.
     def _force_token_ids_generation(self, scores, token_ids):
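The hunk above swaps which return path is commented out: for encoder-decoder models, generation is seeded with a decoder start token (<EOS> here, per the comment), so column 0 of `decoded` holds that seed rather than generated text, and `decoded[:, 1:]` drops it before returning. A toy illustration, assuming a made-up vocabulary in which id 2 plays the <EOS>/seed role:

import torch

# Two finished beam-search hypotheses; column 0 is the seed token that
# started decoding, not generated content.
decoded = torch.tensor([
    [2, 17, 42, 99, 2],
    [2, 17, 55, 23, 2],
])

returned = decoded[:, 1:]  # what the encoder-decoder branch now returns
print(returned.shape)      # torch.Size([2, 4]) - one token shorter per row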
@@ -471,8 +471,7 @@ class BartModelIntegrationTest(unittest.TestCase):
             attention_mask=dct["attention_mask"].to(torch_device),
             num_beams=4,
             length_penalty=2.0,
-            # max_length=max_length + 2,
-            max_length=max_length + 1,
+            max_length=max_length + 2,
             min_length=min_length + 1,
             no_repeat_ngram_size=3,
             do_sample=False,
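The test change appears to follow from the return-path change above: one generated position is now spent on the decoder seed token and then stripped by `decoded[:, 1:]`, so the test budgets `max_length + 2` to still receive up to `max_length + 1` tokens back. A small arithmetic check of that accounting, under the stripped-seed assumption (the concrete `max_length` value is illustrative, not the test's):

max_length = 140             # illustrative value only
requested = max_length + 2   # what the test now passes to generate()
after_strip = requested - 1  # decoded[:, 1:] removes the seed token
assert after_strip == max_length + 1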