Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
Fix memory regression in Seq2Seq example (#9713)
* Fix memory regression in Seq2Seq example
* Fix test and properly deal with -100
* Easier condition with device safety
* Patch for MBartTokenizerFast
parent a7dabfb3d1 · commit 5f80c15ef5
5 changed files with 43 additions and 16 deletions
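The hunks below touch the Seq2Seq example's fine-tuning script and its data collator, the Trainer's loss computation, the LabelSmoother utility, and the trainer-utils tests. As orientation, here is a minimal, self-contained sketch of the gather-based label smoothing the reworked LabelSmoother implements; the helper name smoothed_ce, the toy shapes, and the epsilon value are illustrative assumptions, not part of the commit.

# Illustrative sketch, not the library code: label smoothing with -100 positions
# handled via clamp + gather, mirroring the approach the diff below introduces.
import torch
import torch.nn.functional as F

def smoothed_ce(logits, labels, epsilon=0.1, ignore_index=-100):
    # logits: (batch, seq_len, vocab); labels: (batch, seq_len) with ignore_index on padding
    log_probs = -F.log_softmax(logits, dim=-1)
    labels = labels.unsqueeze(-1)
    padding_mask = labels.eq(ignore_index)
    safe_labels = labels.clamp(min=0)  # gather cannot index with -100; the mask zeroes those slots
    nll_loss = log_probs.gather(dim=-1, index=safe_labels).masked_fill(padding_mask, 0.0)
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True).masked_fill(padding_mask, 0.0)
    num_active = padding_mask.numel() - padding_mask.long().sum()
    nll_loss = nll_loss.sum() / num_active
    smoothed_loss = smoothed_loss.sum() / (num_active * log_probs.shape[-1])
    return (1 - epsilon) * nll_loss + epsilon * smoothed_loss

logits = torch.randn(2, 4, 10)
labels = torch.randint(0, 10, (2, 4))
labels[1, 3] = -100  # padded position that must not contribute to the loss
print(smoothed_ce(logits, labels))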
@@ -26,6 +26,7 @@ from transformers import (
     AutoTokenizer,
     HfArgumentParser,
     MBartTokenizer,
+    MBartTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
@@ -220,11 +221,14 @@ def main():
         data_args.eval_beams = model.config.num_beams
 
     # set decoder_start_token_id for MBart
-    if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
+    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         assert (
             data_args.tgt_lang is not None and data_args.src_lang is not None
         ), "mBart requires --tgt_lang and --src_lang"
-        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        if isinstance(tokenizer, MBartTokenizer):
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        else:
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)
 
     if model_args.freeze_embeds:
         freeze_embeds(model)
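For the new fast-tokenizer branch above, a hedged sketch of what the two lookups resolve to; the checkpoint name and target language code are illustrative, and the snippet assumes both tokenizer classes map the language code to the same vocabulary id.

# Illustrative only: resolving decoder_start_token_id for slow vs. fast MBart tokenizers.
from transformers import MBartTokenizer, MBartTokenizerFast

tgt_lang = "ro_RO"  # assumed --tgt_lang value

tok_slow = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
tok_fast = MBartTokenizerFast.from_pretrained("facebook/mbart-large-cc25")

start_id_slow = tok_slow.lang_code_to_id[tgt_lang]        # slow tokenizer exposes the mapping directly
start_id_fast = tok_fast.convert_tokens_to_ids(tgt_lang)  # fast tokenizer goes through the vocabulary

print(start_id_slow, start_id_fast)  # expected to be the same id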
@@ -284,7 +288,9 @@ def main():
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
+        data_collator=Seq2SeqDataCollator(
+            tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores
+        ),
         compute_metrics=compute_metrics_fn,
         tokenizer=tokenizer,
     )
@@ -33,8 +33,9 @@ from torch import nn
 from torch.utils.data import Dataset, Sampler
 
 from sentence_splitter import add_newline_to_end_of_each_sentence
-from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
+from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
 from transformers.file_utils import cached_property
+from transformers.models.bart.modeling_bart import shift_tokens_right
 
 
 try:
@@ -274,9 +275,10 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
 
 
 class Seq2SeqDataCollator:
-    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
+    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id
+        self.decoder_start_token_id = decoder_start_token_id
         assert (
             self.pad_token_id is not None
         ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
@@ -304,9 +306,15 @@ class Seq2SeqDataCollator:
             labels = trim_batch(labels, self.pad_token_id)
             input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
 
+        if isinstance(self.tokenizer, T5Tokenizer):
+            decoder_input_ids = self._shift_right_t5(labels)
+        else:
+            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)
+
         batch = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
+            "decoder_input_ids": decoder_input_ids,
             "labels": labels,
         }
         return batch
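The collator now builds decoder_input_ids itself rather than leaving that to the model. A minimal sketch of the shift-right it delegates to, with made-up token ids; the helper name shift_right_demo and the final -100 replacement are assumptions of this demo (ignored positions should not reach the decoder), not a claim about the library function's exact behavior.

# Illustrative sketch: building decoder_input_ids by shifting labels one position to the right.
import torch

def shift_right_demo(labels, pad_token_id, decoder_start_token_id):
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # demo assumption: replace any -100 placeholders so the decoder only sees real token ids
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[10, 11, 12, 2], [20, 21, 2, -100]])  # 2 = eos, -100 = ignored position
print(shift_right_demo(labels, pad_token_id=1, decoder_start_token_id=0))
# tensor([[ 0, 10, 11, 12],
#         [ 0, 20, 21,  2]])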
@@ -1297,14 +1297,18 @@ class Trainer:
 
         Subclass and override for custom behavior.
         """
+        if self.label_smoother is not None and "labels" in inputs:
+            labels = inputs.pop("labels")
+        else:
+            labels = None
         outputs = model(**inputs)
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
 
-        if self.label_smoother is not None and "labels" in inputs:
-            return self.label_smoother(outputs, inputs["labels"])
+        if labels is not None:
+            return self.label_smoother(outputs, labels)
         else:
             # We don't use .loss here since the model may return tuples instead of ModelOutput.
             return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
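Popping the labels before the forward pass means the model never computes its own full-vocabulary loss when label smoothing is active, which is likely where the memory saving comes from. A small usage sketch of that flow; ToyModel, its shapes, and the epsilon value are made up for the example, while LabelSmoother is used the same way as in the tests further down.

# Illustrative sketch: pop labels, run the model, let LabelSmoother compute the loss.
import torch
from torch import nn
from transformers.trainer_pt_utils import LabelSmoother

class ToyModel(nn.Module):
    def __init__(self, vocab=10):
        super().__init__()
        self.proj = nn.Linear(4, vocab)

    def forward(self, inputs_embeds, labels=None):
        logits = self.proj(inputs_embeds)
        if labels is not None:
            # the smoothing path now avoids this internal loss computation entirely
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

label_smoother = LabelSmoother(epsilon=0.1)
inputs = {"inputs_embeds": torch.randn(2, 3, 4), "labels": torch.randint(0, 10, (2, 3))}

labels = inputs.pop("labels")           # pop before the forward pass
outputs = ToyModel()(**inputs)          # model returns logits only, no internal loss
loss = label_smoother(outputs, labels)  # smoother computes the smoothed loss from the logits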
@@ -380,17 +380,26 @@ class LabelSmoother:
     ignore_index: int = -100
 
     def __call__(self, model_output, labels):
-        model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
-        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
+        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
         log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
+        if labels.dim() == log_probs.dim() - 1:
+            labels = labels.unsqueeze(-1)
 
-        # Look at the ignored index and mask the corresponding log_probs.
-        padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
-        log_probs.masked_fill_(padding_mask, 0.0)
+        padding_mask = labels.eq(self.ignore_index)
+        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
+        # will ignore them in any case.
+        labels.clamp_min_(0)
+        nll_loss = log_probs.gather(dim=-1, index=labels)
+        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
+
+        nll_loss.masked_fill_(padding_mask, 0.0)
+        smoothed_loss.masked_fill_(padding_mask, 0.0)
 
         # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
-        smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
-        return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
+        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
+        nll_loss = nll_loss.sum() / num_active_elements
+        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
+        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
 
 
 def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
@@ -71,7 +71,7 @@ class TrainerUtilsTest(unittest.TestCase):
         random_logits = torch.randn(4, 5, num_labels)
         random_labels = torch.randint(0, num_labels, (4, 5))
         loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
-        model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
+        model_output = SequenceClassifierOutput(logits=random_logits)
         label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
         log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
         expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean()
@@ -83,7 +83,7 @@ class TrainerUtilsTest(unittest.TestCase):
         random_labels[2, 3] = -100
 
         loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
-        model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
+        model_output = SequenceClassifierOutput(logits=random_logits)
         label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
         log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
         # Mask the log probs with the -100 labels