mirror of https://github.com/saymrwulf/transformers.git
Loss_function

parent f839aa20fe
commit f8a963c116

1 changed file with 3 additions and 9 deletions
@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -732,8 +731,6 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
         return self.decoder(*args, **kwargs)


 def trocr_cross_entropy_loss(
     logits,
     labels,
@@ -755,6 +752,8 @@ def trocr_cross_entropy_loss(
+    if reduction == "sum":
+        loss = loss / num_items_in_batch
     return loss


 @add_start_docstrings(
     "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
     " [`VisionEncoderDecoder`].",
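The two added lines follow the gradient-accumulation-safe pattern Transformers uses for its stock causal-LM losses: compute cross-entropy with sum reduction when the global token count is known, then divide by that count. A minimal runnable sketch of a loss in that shape (the helper name and its defaults are illustrative assumptions, not this file's actual signature):

import torch
import torch.nn.functional as F

def cross_entropy_loss_sketch(logits, labels, num_items_in_batch=None, ignore_index=-100):
    # Hypothetical stand-in for trocr_cross_entropy_loss; flatten to the
    # (N, vocab) / (N,) shapes that F.cross_entropy expects.
    logits = logits.view(-1, logits.size(-1)).float()
    labels = labels.view(-1)
    # Sum when the caller supplies the true token count for the whole
    # (possibly gradient-accumulated) batch; otherwise a plain mean.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        # Same normalization as the two lines added in the hunk above.
        loss = loss / num_items_in_batch
    return loss

logits = torch.randn(2, 5, 11)          # (batch, seq_len, vocab_size)
labels = torch.randint(0, 11, (2, 5))   # (batch, seq_len)
print(cross_entropy_loss_sketch(logits, labels, num_items_in_batch=10))

Dividing a summed loss by num_items_in_batch, rather than taking a per-micro-batch mean, keeps gradients identical whether a batch is processed whole or split across accumulation steps.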
@@ -955,12 +954,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         loss = None
         if labels is not None:
-            loss = self.loss_fn(
-                logits,
-                labels,
-                vocab_size=self.config.vocab_size,
-                **kwargs
-            )
+            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
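The one-line replacement leans on the loss_function hook that recent Transformers versions resolve on the model itself from a configured loss type, with extra **kwargs (e.g. num_items_in_batch from the Trainer) forwarded untouched. A hedged sketch of that dispatch idea with stand-in names (LOSS_MAPPING, TinyModel, and for_causal_lm_loss here are illustrative, not the library's real internals):

import torch
import torch.nn.functional as F

def for_causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None, **kwargs):
    # Stand-in loss: sum-then-normalize when the global token count is known.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss

# Registry keyed by a configured loss type, looked up once per call site.
LOSS_MAPPING = {"ForCausalLM": for_causal_lm_loss}

class TinyModel:
    def __init__(self, loss_type="ForCausalLM", vocab_size=11):
        self.loss_type = loss_type
        self.vocab_size = vocab_size

    @property
    def loss_function(self):
        # A single attribute lookup lets model code shrink to one line:
        # loss = self.loss_function(logits, labels, vocab_size=..., **kwargs)
        return LOSS_MAPPING[self.loss_type]

model = TinyModel()
logits = torch.randn(2, 5, model.vocab_size)
labels = torch.randint(0, model.vocab_size, (2, 5))
print(model.loss_function(logits, labels, vocab_size=model.vocab_size, num_items_in_batch=10))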