Align logits and labels in OPT (#17237)

This commit is contained in:
MichelBartels 2022-05-16 15:37:39 +02:00 committed by GitHub
parent a5d1839679
commit 95b6bef624
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -951,9 +951,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]