Moar bronked

2025-02-07 12:20:52 -05:00
parent ba29a439ad
commit 919bcbeca7
4 changed files with 36 additions and 10 deletions


@@ -593,6 +593,8 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # Note: for now, don't deal with num_items_in_batch
+        kwargs.pop("num_items_in_batch", None)
         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
         kwargs_decoder = {
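
For context on the two added lines: `num_items_in_batch` is popped before the kwargs are split, so it reaches neither the encoder nor the decoder and is consumed only by the loss. A minimal standalone sketch of that routing, with made-up values, and with the `decoder_`-prefix stripping assumed from the surrounding code:

kwargs = {
    "attention_mask": "enc_mask",          # routed to the encoder
    "decoder_attention_mask": "dec_mask",  # routed to the decoder, prefix stripped
    "num_items_in_batch": 8,               # popped; used only for loss normalization
}
kwargs.pop("num_items_in_batch", None)
kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
assert kwargs_encoder == {"attention_mask": "enc_mask"}
assert kwargs_decoder == {"attention_mask": "dec_mask"}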


@@ -43,12 +43,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_moonshine import MoonshineConfig


@@ -7,7 +7,7 @@
 from typing import Callable, List, Optional, Tuple, Union

 import torch
-import torch.nn as nn
+from torch import nn

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache


@@ -732,6 +732,29 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
         return self.decoder(*args, **kwargs)


+def trocr_cross_entropy_loss(
+    logits,
+    labels,
+    num_items_in_batch: int = None,
+    ignore_index: int = -100,
+    vocab_size: int = None,
+):
+    """
+    Loss function for TrOCR that takes `num_items_in_batch` into account.
+    """
+    # move labels to the correct device to enable model parallelism
+    # (labels must stay integer class indices for cross_entropy)
+    labels = labels.to(logits.device)
+    logits = logits.view(-1, vocab_size).float()
+    flat_labels = labels.view(-1)
+    # with gradient accumulation, sum per-token losses and normalize by the
+    # global token count instead of averaging per micro-batch
+    reduction = "sum" if num_items_in_batch is not None else "mean"
+    loss = nn.functional.cross_entropy(logits, flat_labels, ignore_index=ignore_index, reduction=reduction)
+    if reduction == "sum":
+        loss = loss / num_items_in_batch
+    return loss
+
+
 @add_start_docstrings(
     "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
     " [`VisionEncoderDecoder`].",
@@ -752,6 +775,8 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         # Initialize weights and apply final processing
         self.post_init()

+        self._loss_function = trocr_cross_entropy_loss
+
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
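
The reduction switch in trocr_cross_entropy_loss is the heart of the change: under gradient accumulation, a per-micro-batch mean weights every micro-batch equally no matter how many label tokens it holds, while summing and dividing by the global token count weights every token equally. A toy sketch of the two paths (shapes and counts are invented):

import torch

logits = torch.randn(4, 10)                # 4 label tokens, vocab_size=10
labels = torch.randint(0, 10, (4,))
per_token = torch.nn.functional.cross_entropy(logits, labels, reduction="none")

mean_loss = per_token.mean()        # num_items_in_batch=None -> reduction="mean"
global_loss = per_token.sum() / 16  # num_items_in_batch=16   -> reduction="sum", then normalize
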
@@ -786,6 +811,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
@@ -929,9 +955,12 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )

         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
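
End to end, the pieces above thread num_items_in_batch from the caller through forward(**kwargs) into the loss function. A self-contained toy version of the same plumbing (a sketch of the pattern only, not TrOCR itself; every name below is invented):

import torch
from torch import nn


def toy_loss(logits, labels, num_items_in_batch=None, ignore_index=-100, vocab_size=None):
    # same shape handling and reduction switch as trocr_cross_entropy_loss above
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        logits.view(-1, vocab_size), labels.view(-1), ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss


class ToyLM(nn.Module):
    def __init__(self, vocab_size=10, hidden=8):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)
        self._loss_function = toy_loss  # the hook the commit sets in TrOCRForCausalLM.__init__

    def forward(self, input_ids, labels=None, **kwargs):
        # extra kwargs (e.g. num_items_in_batch) flow straight through to the loss
        logits = self.head(self.embed(input_ids))
        loss = None
        if labels is not None:
            loss = self._loss_function(logits, labels, vocab_size=self.vocab_size, **kwargs)
        return loss, logits


model = ToyLM()
input_ids = torch.randint(0, 10, (2, 5))
loss, _ = model(input_ids, labels=input_ids, num_items_in_batch=40)  # 40 = tokens across accumulated steps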