From 256482ac9285c467fb97ca3b1b693a4de1d0ac60 Mon Sep 17 00:00:00 2001 From: Tanmay Garg Date: Sun, 28 Feb 2021 06:04:22 +0530 Subject: [PATCH] Introduce save_strategy training argument (#10286) * Introduce save_strategy training argument * deprecate EvaluationStrategy * collapse EvaluationStrategy and LoggingStrategy into a single IntervalStrategy enum * modify tests to use modified enum --- docs/source/internal/trainer_utils.rst | 2 +- src/transformers/__init__.py | 4 +-- src/transformers/integrations.py | 4 +-- src/transformers/trainer_callback.py | 26 +++++++++++----- src/transformers/trainer_tf.py | 4 +-- src/transformers/trainer_utils.py | 4 +-- src/transformers/training_args.py | 43 +++++++++++++++++++------- src/transformers/training_args_tf.py | 13 ++++++-- src/transformers/utils/notebook.py | 8 ++--- tests/test_trainer.py | 10 +++--- tests/test_trainer_callback.py | 9 ++---- 11 files changed, 81 insertions(+), 46 deletions(-) diff --git a/docs/source/internal/trainer_utils.rst b/docs/source/internal/trainer_utils.rst index 5d787620f..c649eb3ab 100644 --- a/docs/source/internal/trainer_utils.rst +++ b/docs/source/internal/trainer_utils.rst @@ -22,7 +22,7 @@ Utilities .. autoclass:: transformers.EvalPrediction -.. autoclass:: transformers.EvaluationStrategy +.. autoclass:: transformers.IntervalStrategy .. autofunction:: transformers.set_seed diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 41b18ad6d..0856c68ed 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -255,7 +255,7 @@ _import_structure = { "TrainerControl", "TrainerState", ], - "trainer_utils": ["EvalPrediction", "EvaluationStrategy", "SchedulerType", "set_seed"], + "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"], "training_args": ["TrainingArguments"], "training_args_seq2seq": ["Seq2SeqTrainingArguments"], "training_args_tf": ["TFTrainingArguments"], @@ -1429,7 +1429,7 @@ if TYPE_CHECKING: TrainerControl, TrainerState, ) - from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed + from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed from .training_args import TrainingArguments from .training_args_seq2seq import Seq2SeqTrainingArguments from .training_args_tf import TFTrainingArguments diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index e5293a1ca..b427e33e7 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -48,7 +48,7 @@ if _has_comet: from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402 from .trainer_callback import TrainerCallback # noqa: E402 -from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy # noqa: E402 +from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 # Integration functions: @@ -219,7 +219,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting. if isinstance( kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining) - ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == EvaluationStrategy.NO): + ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO): raise RuntimeError( "You are using {cls} as a scheduler but you haven't enabled evaluation during training. " "This means your trials will not report intermediate results to Ray Tune, and " diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index b16f70921..9409f8aaf 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union import numpy as np from tqdm.auto import tqdm -from .trainer_utils import EvaluationStrategy, LoggingStrategy +from .trainer_utils import IntervalStrategy from .training_args import TrainingArguments from .utils import logging @@ -404,20 +404,25 @@ class DefaultFlowCallback(TrainerCallback): if state.global_step == 1 and args.logging_first_step: control.should_log = True if ( - args.logging_strategy == LoggingStrategy.STEPS + args.logging_strategy == IntervalStrategy.STEPS and args.logging_steps > 0 and state.global_step % args.logging_steps == 0 ): control.should_log = True # Evaluate - if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0: + if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0: control.should_evaluate = True if args.load_best_model_at_end: control.should_save = True # Save - if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0: + if ( + not args.load_best_model_at_end + and args.save_strategy == IntervalStrategy.STEPS + and args.save_steps > 0 + and state.global_step % args.save_steps == 0 + ): control.should_save = True # End training @@ -428,14 +433,19 @@ class DefaultFlowCallback(TrainerCallback): def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): # Log - if args.logging_strategy == LoggingStrategy.EPOCH: + if args.logging_strategy == IntervalStrategy.EPOCH: control.should_log = True # Evaluate - if args.evaluation_strategy == EvaluationStrategy.EPOCH: + if args.evaluation_strategy == IntervalStrategy.EPOCH: control.should_evaluate = True if args.load_best_model_at_end: control.should_save = True + + # Save + if args.save_strategy == IntervalStrategy.EPOCH: + control.should_save = True + return control @@ -531,8 +541,8 @@ class EarlyStoppingCallback(TrainerCallback): args.metric_for_best_model is not None ), "EarlyStoppingCallback requires metric_for_best_model is defined" assert ( - args.evaluation_strategy != EvaluationStrategy.NO - ), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch" + args.evaluation_strategy != IntervalStrategy.NO + ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch" def on_evaluate(self, args, state, control, metrics, **kwargs): metric_to_check = args.metric_for_best_model diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index 509d8b77f..184845b85 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -33,7 +33,7 @@ from tensorflow.python.distribute.values import PerReplica from .modeling_tf_utils import TFPreTrainedModel from .optimization_tf import GradientAccumulator, create_optimizer -from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, EvaluationStrategy, PredictionOutput, set_seed +from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed from .training_args_tf import TFTrainingArguments from .utils import logging @@ -574,7 +574,7 @@ class TFTrainer: if ( self.args.eval_steps > 0 - and self.args.evaluation_strategy == EvaluationStrategy.STEPS + and self.args.evaluation_strategy == IntervalStrategy.STEPS and self.global_step % self.args.eval_steps == 0 ): self.evaluate() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index cd70001c7..04dca620c 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -101,13 +101,13 @@ def get_last_checkpoint(folder): return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) -class EvaluationStrategy(ExplicitEnum): +class IntervalStrategy(ExplicitEnum): NO = "no" STEPS = "steps" EPOCH = "epoch" -class LoggingStrategy(ExplicitEnum): +class EvaluationStrategy(ExplicitEnum): NO = "no" STEPS = "steps" EPOCH = "epoch" diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 90c04f89d..c683cb13a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -14,6 +14,7 @@ import json import os +import warnings from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Dict, List, Optional @@ -25,7 +26,7 @@ from .file_utils import ( is_torch_tpu_available, torch_required, ) -from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption +from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption from .utils import logging @@ -84,7 +85,7 @@ class TrainingArguments: :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See the `example scripts `__ for more details. - evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): + evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`): The evaluation strategy to adopt during training. Possible values are: * :obj:`"no"`: No evaluation is done during training. @@ -139,7 +140,7 @@ class TrainingArguments: logging_dir (:obj:`str`, `optional`): `TensorBoard `__ log directory. Will default to `runs/**CURRENT_DATETIME_HOSTNAME**`. - logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`): + logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`): The logging strategy to adopt during training. Possible values are: * :obj:`"no"`: No logging is done during training. @@ -150,8 +151,15 @@ class TrainingArguments: Whether to log and evaluate the first :obj:`global_step` or not. logging_steps (:obj:`int`, `optional`, defaults to 500): Number of update steps between two logs if :obj:`logging_strategy="steps"`. + save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + * :obj:`"no"`: No save is done during training. + * :obj:`"epoch"`: Save is done at the end of each epoch. + * :obj:`"steps"`: Save is done every :obj:`save_steps`. + save_steps (:obj:`int`, `optional`, defaults to 500): - Number of updates steps before two checkpoint saves. + Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`. save_total_limit (:obj:`int`, `optional`): If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in :obj:`output_dir`. @@ -215,8 +223,8 @@ class TrainingArguments: .. note:: - When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved - after each evaluation. + When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and + the model will be saved after each evaluation. metric_for_best_model (:obj:`str`, `optional`): Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`. @@ -297,7 +305,7 @@ class TrainingArguments: do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) - evaluation_strategy: EvaluationStrategy = field( + evaluation_strategy: IntervalStrategy = field( default="no", metadata={"help": "The evaluation strategy to use."}, ) @@ -359,12 +367,16 @@ class TrainingArguments: warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."}) - logging_strategy: LoggingStrategy = field( + logging_strategy: IntervalStrategy = field( default="steps", metadata={"help": "The logging strategy to use."}, ) logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) + save_strategy: IntervalStrategy = field( + default="steps", + metadata={"help": "The checkpoint save strategy to use."}, + ) save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) save_total_limit: Optional[int] = field( default=None, @@ -510,10 +522,19 @@ class TrainingArguments: self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR") if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN - self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy) - self.logging_strategy = LoggingStrategy(self.logging_strategy) + + if isinstance(self.evaluation_strategy, EvaluationStrategy): + warnings.warn( + "using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `IntervalStrategy` instead", + FutureWarning, + ) + + self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy) + self.logging_strategy = IntervalStrategy(self.logging_strategy) + self.save_strategy = IntervalStrategy(self.save_strategy) + self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) - if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO: + if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO: self.do_eval = True if self.eval_steps is None: self.eval_steps = self.logging_steps diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index 2b66d4448..96143ffc0 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -58,7 +58,7 @@ class TFTrainingArguments(TrainingArguments): :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See the `example scripts `__ for more details. - evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): + evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`): The evaluation strategy to adopt during training. Possible values are: * :obj:`"no"`: No evaluation is done during training. @@ -102,7 +102,7 @@ class TFTrainingArguments(TrainingArguments): logging_dir (:obj:`str`, `optional`): `TensorBoard `__ log directory. Will default to `runs/**CURRENT_DATETIME_HOSTNAME**`. - logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`): + logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`): The logging strategy to adopt during training. Possible values are: * :obj:`"no"`: No logging is done during training. @@ -113,8 +113,15 @@ class TFTrainingArguments(TrainingArguments): Whether to log and evaluate the first :obj:`global_step` or not. logging_steps (:obj:`int`, `optional`, defaults to 500): Number of update steps between two logs if :obj:`logging_strategy="steps"`. + save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + * :obj:`"no"`: No save is done during training. + * :obj:`"epoch"`: Save is done at the end of each epoch. + * :obj:`"steps"`: Save is done every :obj:`save_steps`. + save_steps (:obj:`int`, `optional`, defaults to 500): - Number of updates steps before two checkpoint saves. + Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`. save_total_limit (:obj:`int`, `optional`): If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in :obj:`output_dir`. diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index fd986e263..9912b736b 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -19,7 +19,7 @@ from typing import Optional import IPython.display as disp from ..trainer_callback import TrainerCallback -from ..trainer_utils import EvaluationStrategy +from ..trainer_utils import IntervalStrategy def format_time(t): @@ -277,11 +277,11 @@ class NotebookProgressCallback(TrainerCallback): self._force_next_update = False def on_train_begin(self, args, state, control, **kwargs): - self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step" + self.first_column = "Epoch" if args.evaluation_strategy == IntervalStrategy.EPOCH else "Step" self.training_loss = 0 self.last_log = 0 column_names = [self.first_column] + ["Training Loss"] - if args.evaluation_strategy != EvaluationStrategy.NO: + if args.evaluation_strategy != IntervalStrategy.NO: column_names.append("Validation Loss") self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) @@ -306,7 +306,7 @@ class NotebookProgressCallback(TrainerCallback): def on_log(self, args, state, control, logs=None, **kwargs): # Only for when there is no evaluation - if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs: + if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs: values = {"Training Loss": logs["loss"]} # First column is necessarily Step sine we're not in epoch eval strategy values["Step"] = state.global_step diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 2cdec9229..65f303244 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -21,7 +21,7 @@ import unittest import numpy as np -from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available +from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available from transformers.file_utils import WEIGHTS_NAME from transformers.testing_utils import ( get_tests_dir, @@ -852,7 +852,7 @@ class TrainerIntegrationTest(unittest.TestCase): gradient_accumulation_steps=1, per_device_train_batch_size=16, load_best_model_at_end=True, - evaluation_strategy=EvaluationStrategy.EPOCH, + evaluation_strategy=IntervalStrategy.EPOCH, compute_metrics=AlmostAccuracy(), metric_for_best_model="accuracy", ) @@ -867,7 +867,7 @@ class TrainerIntegrationTest(unittest.TestCase): num_train_epochs=20, gradient_accumulation_steps=1, per_device_train_batch_size=16, - evaluation_strategy=EvaluationStrategy.EPOCH, + evaluation_strategy=IntervalStrategy.EPOCH, compute_metrics=AlmostAccuracy(), metric_for_best_model="accuracy", ) @@ -1013,7 +1013,7 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase): output_dir=tmp_dir, learning_rate=0.1, logging_steps=1, - evaluation_strategy=EvaluationStrategy.EPOCH, + evaluation_strategy=IntervalStrategy.EPOCH, num_train_epochs=4, disable_tqdm=True, load_best_model_at_end=True, @@ -1057,7 +1057,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase): output_dir=tmp_dir, learning_rate=0.1, logging_steps=1, - evaluation_strategy=EvaluationStrategy.EPOCH, + evaluation_strategy=IntervalStrategy.EPOCH, num_train_epochs=4, disable_tqdm=True, load_best_model_at_end=True, diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py index e1ef25945..5c0af40f4 100644 --- a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -18,7 +18,7 @@ import unittest from transformers import ( DefaultFlowCallback, - EvaluationStrategy, + IntervalStrategy, PrinterCallback, ProgressCallback, Trainer, @@ -129,15 +129,12 @@ class TrainerCallbackTest(unittest.TestCase): expected_events += ["on_step_begin", "on_step_end"] if step % trainer.args.logging_steps == 0: expected_events.append("on_log") - if ( - trainer.args.evaluation_strategy == EvaluationStrategy.STEPS - and step % trainer.args.eval_steps == 0 - ): + if trainer.args.evaluation_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0: expected_events += evaluation_events.copy() if step % trainer.args.save_steps == 0: expected_events.append("on_save") expected_events.append("on_epoch_end") - if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH: + if trainer.args.evaluation_strategy == IntervalStrategy.EPOCH: expected_events += evaluation_events.copy() expected_events += ["on_log", "on_train_end"] return expected_events