Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00
Align logits and labels in OPT (#17237)
parent a5d1839679
commit 95b6bef624

1 changed file with 5 additions and 2 deletions
@@ -951,9 +951,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
         loss = None
         if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[1:]
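For context, the shift matters because a causal LM scores the logits at position i against the token at position i + 1; without it, each position's prediction is compared to its own input token. Below is a minimal, standalone PyTorch sketch of the shift-and-flatten cross-entropy pattern used in the diff above, with made-up tensor sizes rather than the model's actual forward pass:

import torch
from torch.nn import CrossEntropyLoss

# Toy shapes, for illustration only: batch of 2, sequence length 5, vocab of 11.
batch_size, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_len, vocab_size)          # model output
labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # input ids reused as labels

# Shift so that tokens < n predict n: logits at position i are scored
# against the label at position i + 1, and the final position is dropped.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens and compute the standard cross-entropy loss.
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)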