diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 5d953ab4e..97e24af54 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -951,9 +951,12 @@ class OPTForCausalLM(OPTPreTrainedModel): loss = None if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens loss_fct = CrossEntropyLoss() - - loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:]