From 919bcbeca77f7b2d8cf6c4cd77f1db3a6d0152d1 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Fri, 7 Feb 2025 12:20:52 -0500
Subject: [PATCH] Moar bronked

---
 .../modeling_encoder_decoder.py               |  2 ++
 .../models/moonshine/modeling_moonshine.py    |  7 +---
 .../models/olmo2/modeling_olmo2.py            |  2 +-
 .../models/trocr/modeling_trocr.py            | 35 +++++++++++++++++--
 4 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
index 9ab4b7f2c..2b1a95a11 100644
--- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -593,6 +593,8 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # Note: for now, don't deal with num_items_in_batch
+        kwargs.pop("num_items_in_batch", None)
         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
 
         kwargs_decoder = {
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index fdcb1600d..78b70009e 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -43,12 +43,7 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_moonshine import MoonshineConfig
 
 
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 561b7fdf0..b83ddd20b 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -7,7 +7,7 @@
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
+from torch import nn
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 2a745516c..0021b07dd 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -732,6 +732,29 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
         return self.decoder(*args, **kwargs)
+
+
+def trocr_cross_entropy_loss(
+    logits,
+    labels,
+    num_items_in_batch: int = None,
+    ignore_index: int = -100,
+    vocab_size: int = None,
+):
+    """
+    Loss function for trocr that takes into account `num_items_in_batch`
+    """
+    # move labels to correct device to enable model parallelism
+    labels = labels.float().to(logits.device)
+
+    logits = logits.view(-1, vocab_size).float()
+    shift_labels = labels.view(-1)
+
+    reduction = "sum" if num_items_in_batch is not None else "mean"
+    loss = nn.functional.cross_entropy(logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
+    if reduction == "sum":
+        loss = loss / num_items_in_batch
+    return loss
 
 
 @add_start_docstrings(
     "The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and"
     " [`VisionEncoderDecoder`].",
@@ -752,6 +775,8 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         # Initialize weights and apply final processing
         self.post_init()
 
+        self._loss_function = trocr_cross_entropy_loss
+
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
 
@@ -786,6 +811,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Args:
@@ -929,9 +955,12 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-
+            loss = self.loss_fn(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs
+            )
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output