Mirror of https://github.com/saymrwulf/transformers.git
TF Model train and eval step metrics for seq2seq models. (#14009)
* TF Model train and eval step metrics for seq2seq models. When using a model with a seq2seq output, compute metrics against the logits.

* Removing vestigial code

Co-authored-by: matt <rocketknight1@gmail.com>
parent fde4867f97
commit 122c2f81b7
2 changed files with 40 additions and 8 deletions
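
Before the diff itself, a minimal usage sketch of the behaviour this change enables (not part of the commit; the checkpoint name and metric are illustrative, and it mirrors the T5 test added below): a Keras metric passed to compile() is now updated with the logits of a seq2seq model rather than the whole output object.

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

# Illustrative checkpoint and metric; any TF seq2seq model behaves the same way.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# Labels ride along in the input dict; the model computes its own loss from them.
inputs = dict(tokenizer(["sentiment: Everything is awesome!"], return_tensors="tf"))
inputs["labels"] = tokenizer(["positive"], return_tensors="tf").input_ids

# No loss is passed to compile(); the compiled metric now receives (labels, logits).
model.compile(optimizer="adam", metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(inputs, epochs=1)
print(model.evaluate(inputs))  # [loss, sparse_categorical_accuracy]
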
src/transformers/modeling_tf_utils.py

@@ -43,6 +43,7 @@ from .file_utils import (
    is_remote_url,
)
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tokenization_utils_base import BatchEncoding
from .utils import logging

@@ -787,6 +788,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
        # Run backwards pass.
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        # When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
        # should be done only with the relevant ModelOutput param that is
        # considered by the loss.
        if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
            y_pred = y_pred["logits"]
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
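
A standalone sketch of the metric update that the new isinstance branch makes possible (illustrative only; the shapes and vocab size are made up): once y_pred is narrowed to its logits tensor, a stock Keras metric can consume it directly, which is what compiled_metrics.update_state(y, y_pred, sample_weight) then does.

import tensorflow as tf

metric = tf.keras.metrics.SparseCategoricalAccuracy()
labels = tf.constant([[42, 7, 1]])            # (batch, target_len) token ids, made up
logits = tf.random.normal((1, 3, 32128))      # (batch, target_len, vocab_size), made up

# Mirrors compiled_metrics.update_state(y, y_pred, sample_weight) after
# y_pred has been replaced by y_pred["logits"].
metric.update_state(labels, logits)
print(float(metric.result()))
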
@@ -813,17 +819,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        if y is None and "labels" in x:
            y = x["labels"]  # Stops confusion with metric computations
        y_pred = self(x, training=False)
        if not self.loss:
            self.loss_tracker.update_state(y_pred.loss)
            return_metrics = {"loss": self.loss_tracker.result()}
        else:
            # Run anyway to update state
            self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
            return_metrics = {}
        # Updates stateful loss metrics.
        self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
        # Updates stateful loss metrics.
        if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
            y_pred = y_pred["logits"]
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
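
For the "if not self.loss:" branch above, a small sketch of what the loss_tracker bookkeeping amounts to (an assumption: loss_tracker is a tf.keras.metrics.Mean owned by the model, which the update_state()/result() calls suggest; it is defined outside this hunk).

import tensorflow as tf

# Assumed stand-ins: loss_tracker as a Mean metric, and the loss the model
# returned internally (y_pred.loss) for one evaluation batch.
loss_tracker = tf.keras.metrics.Mean(name="loss")
internal_loss = tf.constant([2.7])

loss_tracker.update_state(internal_loss)
print({"loss": float(loss_tracker.result())})  # what evaluate() reports when no loss is compiled
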
tests/test_modeling_tf_t5.py

@@ -666,3 +666,33 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
        translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        self.assertEqual(translation, expected_translation)

    def test_finetune_keras_trainer(self):
        """Ensure that the model can be fine-tuned via the keras API and
        that metrics work as expected.
        """

        # This metric expects to be called with the logits output
        def _accuracy(y_true, y_pred):
            return tf.keras.metrics.sparse_categorical_accuracy(y_true[:, 0], y_pred[:, 0])

        # measure the accuracy of the first token
        class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
            def __init__(self, name="accuracy", **kwargs):
                super().__init__(_accuracy, name=name, **kwargs)

        model = self.model
        model.compile("adam", metrics=FirstTokenAccuracy())
        tokenizer = T5Tokenizer.from_pretrained("t5-small")

        examples = [
            ("sentiment: Everything is awesome!", "positive"),
            ("sentiment: Tensorflow datasets are hard to use", "negative"),
        ]

        inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
        inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids

        model.fit(inputs)
        m = model.evaluate(inputs)
        self.assertEqual(len(m), 2)
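
As a follow-up usage note (a hedged sketch, not part of the commit): because train_step and test_step fall back to x["labels"] when no separate y is supplied, the same kind of dict the test builds can equally be wrapped in a tf.data.Dataset.

import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

# Illustrative checkpoint and metric, reusing the example strings from the test above.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
model.compile("adam", metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

features = dict(tokenizer(["sentiment: Everything is awesome!"], return_tensors="tf"))
features["labels"] = tokenizer(["positive"], return_tensors="tf").input_ids

dataset = tf.data.Dataset.from_tensor_slices(features).batch(1)
model.fit(dataset, epochs=1)
print(model.evaluate(dataset))  # [loss, sparse_categorical_accuracy] -> two entries, as asserted in the test
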