From ae06bce888c665dfdb5315b27865f2b0aa2f7d82 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 6 Dec 2022 20:37:01 +0800
Subject: [PATCH] exclude jit time from the speed metric calculation of evaluation and prediction (#20553)

Signed-off-by: Wang, Yi A
Signed-off-by: Wang, Yi A
---
 .../pytorch/question-answering/trainer_qa.py |  8 ++++++-
 .../question-answering/trainer_seq2seq_qa.py | 24 +++++++++++++++----
 src/transformers/modelcard.py                |  1 +
 src/transformers/trainer.py                  |  8 +++++++
 src/transformers/trainer_utils.py            |  6 ++++-
 src/transformers/utils/notebook.py           |  1 +
 6 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py
index e67d53eb9..a486405b6 100644
--- a/examples/pytorch/question-answering/trainer_qa.py
+++ b/examples/pytorch/question-answering/trainer_qa.py
@@ -51,10 +51,13 @@ class QuestionAnsweringTrainer(Trainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -74,7 +77,7 @@ class QuestionAnsweringTrainer(Trainer):
                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
             metrics.update(output.metrics)
         else:
-            metrics = {}
+            metrics = output.metrics
 
         if self.args.should_log:
             # Only the main node log the results by default
@@ -103,10 +106,13 @@ class QuestionAnsweringTrainer(Trainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py
index 90acc0520..73517c06d 100644
--- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py
+++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py
@@ -71,10 +71,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -94,9 +97,9 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 if not key.startswith(f"{metric_key_prefix}_"):
                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
 
-            output.metrics.update(metrics)
+            metrics.update(output.metrics)
         else:
-            metrics = {}
+            metrics = output.metrics
 
         if self.args.should_log:
             # Only the main node log the results by default
@@ -106,7 +109,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())
 
-        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
         return metrics
 
     def predict(
@@ -119,6 +122,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
         self.compute_metrics = None
+        start_time = time.time()
         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
         try:
             output = eval_loop(
@@ -128,10 +132,22 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
 
         if self.post_process_function is None or self.compute_metrics is None:
             return output
@@ -142,5 +158,5 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         for key in list(metrics.keys()):
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
-
+        metrics.update(output.metrics)
         return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
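Every hunk above applies the same fix: before speed_metrics measures elapsed time, start_time is advanced by any recorded {metric_key_prefix}_jit_compilation_time. Since the helper computes runtime as time.time() - start_time, moving the start forward is equivalent to subtracting the one-off JIT compilation cost from the measured runtime, so the reported throughput covers inference only. A minimal standalone sketch of the idea follows; the name jit_adjusted_speed_metrics is illustrative, not from this patch, and the throughput math is written inline rather than calling transformers.trainer_utils.speed_metrics:

import math
import time


def jit_adjusted_speed_metrics(split, start_time, metrics, num_samples, total_batch_size):
    # If the evaluation loop recorded a one-off JIT compilation cost, shift the
    # start timestamp forward so that runtime = end - start no longer includes it.
    if f"{split}_jit_compilation_time" in metrics:
        start_time += metrics[f"{split}_jit_compilation_time"]
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    if runtime > 0:
        result[f"{split}_samples_per_second"] = round(num_samples / runtime, 3)
        num_steps = math.ceil(num_samples / total_batch_size)
        result[f"{split}_steps_per_second"] = round(num_steps / runtime, 3)
    return result
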
diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py
index b28bc792c..4c93b810e 100644
--- a/src/transformers/modelcard.py
+++ b/src/transformers/modelcard.py
@@ -766,6 +766,7 @@ def parse_log_history(log_history):
             _ = metrics.pop("eval_runtime", None)
             _ = metrics.pop("eval_samples_per_second", None)
             _ = metrics.pop("eval_steps_per_second", None)
+            _ = metrics.pop("eval_jit_compilation_time", None)
             values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
             for k, v in metrics.items():
                 if k == "eval_loss":
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 03b036f0f..0df75df22 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1345,7 +1345,9 @@ class Trainer:
             model = nn.DataParallel(model)
 
         if self.args.jit_mode_eval:
+            start_time = time.time()
             model = self.torch_jit_model_eval(model, dataloader, training)
+            self.jit_compilation_time = round(time.time() - start_time, 4)
 
         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
@@ -2819,6 +2821,8 @@ class Trainer:
         )
 
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -2886,6 +2890,8 @@ class Trainer:
             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
         )
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -3102,6 +3108,8 @@ class Trainer:
 
         if all_losses is not None:
             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+        if hasattr(self, "jit_compilation_time"):
+            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
 
         # Prefix all keys with metric_key_prefix + '_'
         for key in list(metrics.keys()):
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index a298fc1de..02379c07a 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -224,7 +224,11 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
     loss = metrics.pop("eval_loss", None)
     _ = metrics.pop("epoch", None)
     # Remove speed metrics
-    speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
+    speed_metrics = [
+        m
+        for m in metrics.keys()
+        if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
+    ]
     for sm in speed_metrics:
         _ = metrics.pop(sm, None)
     return loss if len(metrics) == 0 else sum(metrics.values())
diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py
index 636cf785e..7894f4ad9 100644
--- a/src/transformers/utils/notebook.py
+++ b/src/transformers/utils/notebook.py
@@ -339,6 +339,7 @@ class NotebookProgressCallback(TrainerCallback):
                 _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
                 _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
                 _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
+                _ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
                 for k, v in metrics.items():
                     if k == f"{metric_key_prefix}_loss":
                         values["Validation Loss"] = v
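For a sense of the distortion this removes, here is a worked example with made-up numbers: 1000 samples, 30 s of measured wall clock, of which 12 s was one-off tracing/warm-up under jit_mode_eval. All three figures are illustrative, not from the patch:

num_samples = 1000
wall_clock = 30.0  # total measured evaluation time, in seconds
jit_time = 12.0    # one-off JIT trace/warm-up cost included in it

naive = num_samples / wall_clock                  # ~33.3 samples/s, compilation included
adjusted = num_samples / (wall_clock - jit_time)  # ~55.6 samples/s, inference only
print(f"naive: {naive:.1f} samples/s, adjusted: {adjusted:.1f} samples/s")

The adjusted figure is what eval_samples_per_second now reports, while the compilation cost remains visible as its own eval_jit_compilation_time metric instead of silently deflating throughput.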