mirror of https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00

Moar bronked

This commit is contained in:
parent ba29a439ad
commit 919bcbeca7

4 changed files with 36 additions and 10 deletions
@@ -593,6 +593,8 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # Note: for now, don't deal with num_items_in_batch
+        kwargs.pop("num_items_in_batch", None)
         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

         kwargs_decoder = {
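A minimal standalone sketch (toy dictionary values, not the library call itself) of why the added `kwargs.pop` matters: every key left in `kwargs` is split on the `decoder_` prefix and forwarded to the sub-models, which would not accept an unexpected `num_items_in_batch` keyword.

```python
# Standalone sketch with made-up values: popping the key first keeps it out of
# the encoder/decoder kwargs that forward() builds right after.
kwargs = {"decoder_input_ids": [[0]], "attention_mask": [[1]], "num_items_in_batch": 7}

kwargs.pop("num_items_in_batch", None)  # same guard as the added line above

kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
kwargs_decoder = {argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")}

print(kwargs_encoder)  # {'attention_mask': [[1]]} -- no stray keyword left over
print(kwargs_decoder)  # {'input_ids': [[0]]}
```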
@@ -43,12 +43,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_moonshine import MoonshineConfig

@@ -7,7 +7,7 @@
 from typing import Callable, List, Optional, Tuple, Union

 import torch
-import torch.nn as nn
+from torch import nn

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -732,6 +732,29 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
         return self.decoder(*args, **kwargs)


+def trocr_cross_entropy_loss(
+    logits,
+    labels,
+    num_items_in_batch: int = None,
+    ignore_index: int = -100,
+    vocab_size: int = None,
+):
+    """
+    Loss function for trocr that takes into account `num_items_in_batch`
+    """
+    # move labels to correct device to enable model parallelism
+    labels = labels.float().to(logits.device)
+
+    logits = logits.view(-1, vocab_size).float()
+    shift_labels = labels.view(-1)
+
+    reduction = "sum" if num_items_in_batch is not None else "mean"
+    loss = nn.functional.cross_entropy(logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
+    if reduction == "sum":
+        loss = loss / num_items_in_batch
+    return loss
+
+
 @add_start_docstrings(
     "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
     " [`VisionEncoderDecoder`].",
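A self-contained sketch of the normalization the new helper is after (toy shapes and a made-up global count, not the helper itself): when `num_items_in_batch` is given, the cross-entropy is summed and divided once by that global token count, so splitting a batch across gradient-accumulation steps keeps the same loss scale as one large batch; otherwise a plain per-micro-batch mean is used.

```python
# Toy illustration of the reduction switch; all values are made up for the example.
import torch
from torch import nn

logits = torch.randn(4, 10)             # 4 label positions, vocab size 10
labels = torch.tensor([1, 2, -100, 3])  # one position masked via ignore_index

# Default path (`num_items_in_batch is None`): mean over this micro-batch only.
per_batch_mean = nn.functional.cross_entropy(logits, labels, ignore_index=-100)

# With a global count (e.g. all non-padded labels across accumulation steps):
# sum first, divide once, so every label position carries the same weight.
num_items_in_batch = 12                 # hypothetical global count
summed = nn.functional.cross_entropy(logits, labels, ignore_index=-100, reduction="sum")
normalized = summed / num_items_in_batch
print(per_batch_mean.item(), normalized.item())
```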
@@ -752,6 +775,8 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         # Initialize weights and apply final processing
         self.post_init()

+        self._loss_function = trocr_cross_entropy_loss
+
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens

@@ -786,6 +811,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
@@ -929,9 +955,12 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            loss = self.loss_fn(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs
+            )

         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
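A minimal sketch of the plumbing this last hunk relies on (hypothetical toy functions, not the library API): `forward` now accepts `**kwargs`, so a caller-supplied `num_items_in_batch` flows untouched into the loss helper, which switches to sum-reduction and divides by that count.

```python
# Toy reimplementation of the kwargs path; names and shapes are assumptions.
import torch
from torch import nn

def toy_loss(logits, labels, vocab_size=None, num_items_in_batch=None, ignore_index=-100, **_):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        logits.view(-1, vocab_size), labels.view(-1), ignore_index=ignore_index, reduction=reduction
    )
    return loss / num_items_in_batch if reduction == "sum" else loss

def toy_forward(logits, labels=None, vocab_size=None, **kwargs):
    # mirrors `loss = self.loss_fn(logits, labels, vocab_size=..., **kwargs)` in the hunk above
    return toy_loss(logits, labels, vocab_size=vocab_size, **kwargs) if labels is not None else None

logits = torch.randn(2, 3, 10)           # (batch, sequence, vocab)
labels = torch.randint(0, 10, (2, 3))
print(toy_forward(logits, labels, vocab_size=10, num_items_in_batch=6))
```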