Loss function

This commit is contained in:
2025-02-07 12:35:22 -05:00
parent f839aa20fe
commit f8a963c116

View file

@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
@@ -732,8 +731,6 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
return self.decoder(*args, **kwargs)
def trocr_cross_entropy_loss(
logits,
labels,
@@ -755,6 +752,8 @@ def trocr_cross_entropy_loss(
if reduction == "sum":
loss = loss / num_items_in_batch
return loss
@add_start_docstrings(
"The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
" [`VisionEncoderDecoder`].",
@@ -955,12 +954,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_fn(
logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs
)
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output