From 122c2f81b7ba2eb91c52d4a54a2bca7260e327a4 Mon Sep 17 00:00:00 2001
From: Pedro Marques
Date: Tue, 19 Oct 2021 13:14:21 +0200
Subject: [PATCH] TF Model train and eval step metrics for seq2seq models.
 (#14009)

* TF Model train and eval step metrics for seq2seq models.

When using a model with a seq2seq output compute metrics against logits.

* Removing vestigial code

Co-authored-by: matt
---
 src/transformers/modeling_tf_utils.py | 18 ++++++++++--------
 tests/test_modeling_tf_t5.py          | 30 ++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 4ae4cf239..001c42a42 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -43,6 +43,7 @@ from .file_utils import (
     is_remote_url,
 )
 from .generation_tf_utils import TFGenerationMixin
+from .modeling_tf_outputs import TFSeq2SeqLMOutput
 from .tokenization_utils_base import BatchEncoding
 from .utils import logging
 
@@ -787,6 +788,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
         # Run backwards pass.
         self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+        # When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
+        # should be done only with the relevant ModelOutput param that is
+        # considered by the loss.
+        if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
+            y_pred = y_pred["logits"]
         self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
         return_metrics = {}
@@ -813,17 +819,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if y is None and "labels" in x:
             y = x["labels"]  # Stops confusion with metric computations
         y_pred = self(x, training=False)
-        if not self.loss:
-            self.loss_tracker.update_state(y_pred.loss)
-            return_metrics = {"loss": self.loss_tracker.result()}
-        else:
-            # Run anyway to update state
-            self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
-            return_metrics = {}
-        # Updates stateful loss metrics.
         self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+        # Updates stateful loss metrics.
+        if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
+            y_pred = y_pred["logits"]
         self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
+        return_metrics = {}
         for metric in self.metrics:
             result = metric.result()
             if isinstance(result, dict):
diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py
index 55f7c8627..59ee70c53 100644
--- a/tests/test_modeling_tf_t5.py
+++ b/tests/test_modeling_tf_t5.py
@@ -666,3 +666,33 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
 
         translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
         self.assertEqual(translation, expected_translation)
+
+    def test_finetune_keras_trainer(self):
+        """Ensure that the model can be fine-tuned via the keras API and
+        that metrics work as expected.
+        """
+
+        # This metric expects to be called with the logits output
+        def _accuracy(y_true, y_pred):
+            return tf.keras.metrics.sparse_categorical_crossentropy(y_true[:, 0], y_pred[:, 0])
+
+        # measure the accuracy of the first token
+        class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
+            def __init__(self, name="accuracy", **kwargs):
+                super().__init__(_accuracy, name=name, **kwargs)
+
+        model = self.model
+        model.compile("adam", metrics=FirstTokenAccuracy())
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        examples = [
+            ("sentiment: Everything is awesome!", "positive"),
+            ("sentiment: Tensorflow datasets are hard to use", "negative"),
+        ]
+
+        inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
+        inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids
+
+        model.fit(inputs)
+        m = model.evaluate(inputs)
+        self.assertEqual(len(m), 2)