diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 37482199c..82434100c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1084,10 +1084,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token_logits = next_token_logits / temperature
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
 
-            # if (
-            #     self.config.is_encoder_decoder and do_sample is False
-            # ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
-            #     scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
+            if self.config.is_encoder_decoder and do_sample is False:
+                # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+                scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
 
             # set eos token prob to zero if min_length is not reached
             if eos_token_ids is not None and cur_len < min_length:
@@ -1279,10 +1278,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert (len(hypo) == max_length for hypo in best)
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
 
-        # if self.config.is_encoder_decoder:
-        # do not return first token
-        # return decoded[:, 1:]
-        return decoded
+        if self.config.is_encoder_decoder:
+            # do not return first token
+            return decoded[:, 1:]
+        return decoded
 
     # force one of token_ids to be generated by setting prob of all other tokens to 0.
     def _force_token_ids_generation(self, scores, token_ids):
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index dacee2631..b23f01066 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -471,8 +471,7 @@ class BartModelIntegrationTest(unittest.TestCase):
             attention_mask=dct["attention_mask"].to(torch_device),
             num_beams=4,
             length_penalty=2.0,
-            # max_length=max_length + 2,
-            max_length=max_length + 1,
+            max_length=max_length + 2,
             min_length=min_length + 1,
             no_repeat_ngram_size=3,
             do_sample=False,
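
A minimal sketch (not part of the patch; token ids and tensor values are made up) of the two behaviors the hunks above enable: zeroing the probability of every token but one, as the "force one of token_ids to be generated" comment describes for _force_token_ids_generation, and dropping the leading decoder start token from encoder-decoder output via decoded[:, 1:].

    import torch
    import torch.nn.functional as F

    # Toy log-probs for 2 beams over a 5-token vocabulary (values made up).
    scores = F.log_softmax(torch.randn(2, 5), dim=-1)

    # Force token id 4 (a hypothetical EOS id): set every other token's
    # log-prob to -inf, i.e. its probability to 0, so beam search can
    # only ever pick the forced token at this step.
    mask = torch.full_like(scores, float("-inf"))
    mask[:, 4] = 0.0
    forced = scores + mask  # only column 4 stays finite
    assert forced.argmax(dim=-1).eq(4).all()

    # For encoder-decoder models (e.g. BART) the patch returns decoded[:, 1:],
    # dropping the decoder start token (id 0 here, made up) that every
    # finished hypothesis begins with.
    decoded = torch.tensor([[0, 31, 7, 2], [0, 31, 8, 2]])
    print(decoded[:, 1:])  # tensor([[31,  7,  2], [31,  8,  2]])

The masking works in log space: a log-prob of -inf corresponds to probability 0, so adding the mask leaves the forced column untouched while making all other tokens unselectable.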