From 95b6bef624bd9dfdfcdfdedd86bb2173f7fb4bfe Mon Sep 17 00:00:00 2001
From: MichelBartels
Date: Mon, 16 May 2022 15:37:39 +0200
Subject: [PATCH] Align logits and labels in OPT (#17237)

---
 src/transformers/models/opt/modeling_opt.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 5d953ab4e..97e24af54 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -951,9 +951,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
 
         loss = None
         if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[1:]