diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 0021b07dd..97db40169 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -732,8 +731,6 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
         return self.decoder(*args, **kwargs)
 
 
-
-
 def trocr_cross_entropy_loss(
     logits,
     labels,
@@ -755,6 +752,8 @@ def trocr_cross_entropy_loss(
     if reduction == "sum":
         loss = loss / num_items_in_batch
     return loss
+
+
 @add_start_docstrings(
     "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
     " [`VisionEncoderDecoder`].",
@@ -955,12 +954,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         loss = None
         if labels is not None:
-            loss = self.loss_fn(
-                logits,
-                labels,
-                vocab_size=self.config.vocab_size,
-                **kwargs
-            )
+            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
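
Note (reviewer sketch, not part of the patch): the tail of `trocr_cross_entropy_loss` visible in the hunk above shows the contract the new `self.loss_function(...)` call relies on: when the trainer passes `num_items_in_batch`, the loss is computed with `reduction="sum"` and then normalized by the true token count, so gradient accumulation does not skew the average. A minimal self-contained sketch of that contract, assuming the hypothetical name `sketch_cross_entropy_loss` and the standard `ignore_index=-100` label-masking convention:

    # Sketch only; not the repository's exact implementation.
    import torch
    import torch.nn.functional as F

    def sketch_cross_entropy_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100):
        # Flatten to (batch * seq_len, vocab_size), as cross_entropy expects.
        logits = logits.view(-1, vocab_size).float()
        labels = labels.view(-1).to(logits.device)
        # Sum per-token losses when the caller supplies the global token count,
        # otherwise fall back to a plain per-batch mean.
        reduction = "sum" if num_items_in_batch is not None else "mean"
        loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
        if reduction == "sum":
            # Normalize by the number of real (non-padding) tokens across all
            # accumulation steps, matching the `return loss` path in the diff.
            loss = loss / num_items_in_batch
        return loss

Summing and then dividing by `num_items_in_batch`, rather than taking a per-microbatch `mean`, appears to be the point of routing the loss through `loss_function` here: it keeps the loss scale consistent when batches are split across gradient-accumulation steps.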