mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Introduce save_strategy training argument (#10286)
* Introduce save_strategy training argument * deprecate EvaluationStrategy * collapse EvaluationStrategy and LoggingStrategy into a single IntervalStrategy enum * modify tests to use modified enum
This commit is contained in:
parent
aca6288ff4
commit
256482ac92
11 changed files with 81 additions and 46 deletions
|
|
@ -22,7 +22,7 @@ Utilities
|
|||
|
||||
.. autoclass:: transformers.EvalPrediction
|
||||
|
||||
.. autoclass:: transformers.EvaluationStrategy
|
||||
.. autoclass:: transformers.IntervalStrategy
|
||||
|
||||
.. autofunction:: transformers.set_seed
|
||||
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ _import_structure = {
|
|||
"TrainerControl",
|
||||
"TrainerState",
|
||||
],
|
||||
"trainer_utils": ["EvalPrediction", "EvaluationStrategy", "SchedulerType", "set_seed"],
|
||||
"trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
|
||||
"training_args": ["TrainingArguments"],
|
||||
"training_args_seq2seq": ["Seq2SeqTrainingArguments"],
|
||||
"training_args_tf": ["TFTrainingArguments"],
|
||||
|
|
@ -1429,7 +1429,7 @@ if TYPE_CHECKING:
|
|||
TrainerControl,
|
||||
TrainerState,
|
||||
)
|
||||
from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
|
||||
from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
|
||||
from .training_args import TrainingArguments
|
||||
from .training_args_seq2seq import Seq2SeqTrainingArguments
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ if _has_comet:
|
|||
|
||||
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
|
||||
from .trainer_callback import TrainerCallback # noqa: E402
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy # noqa: E402
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
|
||||
|
||||
|
||||
# Integration functions:
|
||||
|
|
@ -219,7 +219,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||
# Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
|
||||
if isinstance(
|
||||
kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
|
||||
) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == EvaluationStrategy.NO):
|
||||
) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
|
||||
raise RuntimeError(
|
||||
"You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
|
||||
"This means your trials will not report intermediate results to Ray Tune, and "
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Union
|
|||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy
|
||||
from .trainer_utils import IntervalStrategy
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
|
|
@ -404,20 +404,25 @@ class DefaultFlowCallback(TrainerCallback):
|
|||
if state.global_step == 1 and args.logging_first_step:
|
||||
control.should_log = True
|
||||
if (
|
||||
args.logging_strategy == LoggingStrategy.STEPS
|
||||
args.logging_strategy == IntervalStrategy.STEPS
|
||||
and args.logging_steps > 0
|
||||
and state.global_step % args.logging_steps == 0
|
||||
):
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||
if args.evaluation_strategy == IntervalStrategy.STEPS and state.global_step % args.eval_steps == 0:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
|
||||
# Save
|
||||
if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
|
||||
if (
|
||||
not args.load_best_model_at_end
|
||||
and args.save_strategy == IntervalStrategy.STEPS
|
||||
and args.save_steps > 0
|
||||
and state.global_step % args.save_steps == 0
|
||||
):
|
||||
control.should_save = True
|
||||
|
||||
# End training
|
||||
|
|
@ -428,14 +433,19 @@ class DefaultFlowCallback(TrainerCallback):
|
|||
|
||||
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
# Log
|
||||
if args.logging_strategy == LoggingStrategy.EPOCH:
|
||||
if args.logging_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_log = True
|
||||
|
||||
# Evaluate
|
||||
if args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
if args.evaluation_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_evaluate = True
|
||||
if args.load_best_model_at_end:
|
||||
control.should_save = True
|
||||
|
||||
# Save
|
||||
if args.save_strategy == IntervalStrategy.EPOCH:
|
||||
control.should_save = True
|
||||
|
||||
return control
|
||||
|
||||
|
||||
|
|
@ -531,8 +541,8 @@ class EarlyStoppingCallback(TrainerCallback):
|
|||
args.metric_for_best_model is not None
|
||||
), "EarlyStoppingCallback requires metric_for_best_model is defined"
|
||||
assert (
|
||||
args.evaluation_strategy != EvaluationStrategy.NO
|
||||
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
|
||||
args.evaluation_strategy != IntervalStrategy.NO
|
||||
), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics, **kwargs):
|
||||
metric_to_check = args.metric_for_best_model
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from tensorflow.python.distribute.values import PerReplica
|
|||
|
||||
from .modeling_tf_utils import TFPreTrainedModel
|
||||
from .optimization_tf import GradientAccumulator, create_optimizer
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, EvaluationStrategy, PredictionOutput, set_seed
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
from .utils import logging
|
||||
|
||||
|
|
@ -574,7 +574,7 @@ class TFTrainer:
|
|||
|
||||
if (
|
||||
self.args.eval_steps > 0
|
||||
and self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||
and self.args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and self.global_step % self.args.eval_steps == 0
|
||||
):
|
||||
self.evaluate()
|
||||
|
|
|
|||
|
|
@ -101,13 +101,13 @@ def get_last_checkpoint(folder):
|
|||
return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
|
||||
|
||||
|
||||
class EvaluationStrategy(ExplicitEnum):
|
||||
class IntervalStrategy(ExplicitEnum):
|
||||
NO = "no"
|
||||
STEPS = "steps"
|
||||
EPOCH = "epoch"
|
||||
|
||||
|
||||
class LoggingStrategy(ExplicitEnum):
|
||||
class EvaluationStrategy(ExplicitEnum):
|
||||
NO = "no"
|
||||
STEPS = "steps"
|
||||
EPOCH = "epoch"
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
|
@ -25,7 +26,7 @@ from .file_utils import (
|
|||
is_torch_tpu_available,
|
||||
torch_required,
|
||||
)
|
||||
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
|
||||
from .trainer_utils import EvaluationStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
|
||||
from .utils import logging
|
||||
|
||||
|
||||
|
|
@ -84,7 +85,7 @@ class TrainingArguments:
|
|||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||
details.
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
The evaluation strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No evaluation is done during training.
|
||||
|
|
@ -139,7 +140,7 @@ class TrainingArguments:
|
|||
logging_dir (:obj:`str`, `optional`):
|
||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The logging strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No logging is done during training.
|
||||
|
|
@ -150,8 +151,15 @@ class TrainingArguments:
|
|||
Whether to log and evaluate the first :obj:`global_step` or not.
|
||||
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
|
||||
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The checkpoint save strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No save is done during training.
|
||||
* :obj:`"epoch"`: Save is done at the end of each epoch.
|
||||
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
|
||||
|
||||
save_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of updates steps before two checkpoint saves.
|
||||
Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
|
||||
save_total_limit (:obj:`int`, `optional`):
|
||||
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
|
||||
:obj:`output_dir`.
|
||||
|
|
@ -215,8 +223,8 @@ class TrainingArguments:
|
|||
|
||||
.. note::
|
||||
|
||||
When set to :obj:`True`, the parameters :obj:`save_steps` will be ignored and the model will be saved
|
||||
after each evaluation.
|
||||
When set to :obj:`True`, the parameters :obj:`save_strategy` and :obj:`save_steps` will be ignored and
|
||||
the model will be saved after each evaluation.
|
||||
metric_for_best_model (:obj:`str`, `optional`):
|
||||
Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two different
|
||||
models. Must be the name of a metric returned by the evaluation with or without the prefix :obj:`"eval_"`.
|
||||
|
|
@ -297,7 +305,7 @@ class TrainingArguments:
|
|||
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
||||
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
|
||||
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
|
||||
evaluation_strategy: EvaluationStrategy = field(
|
||||
evaluation_strategy: IntervalStrategy = field(
|
||||
default="no",
|
||||
metadata={"help": "The evaluation strategy to use."},
|
||||
)
|
||||
|
|
@ -359,12 +367,16 @@ class TrainingArguments:
|
|||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||
|
||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||
logging_strategy: LoggingStrategy = field(
|
||||
logging_strategy: IntervalStrategy = field(
|
||||
default="steps",
|
||||
metadata={"help": "The logging strategy to use."},
|
||||
)
|
||||
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
|
||||
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
|
||||
save_strategy: IntervalStrategy = field(
|
||||
default="steps",
|
||||
metadata={"help": "The checkpoint save strategy to use."},
|
||||
)
|
||||
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
|
||||
save_total_limit: Optional[int] = field(
|
||||
default=None,
|
||||
|
|
@ -510,10 +522,19 @@ class TrainingArguments:
|
|||
self.output_dir = os.getenv("SM_OUTPUT_DATA_DIR")
|
||||
if self.disable_tqdm is None:
|
||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
||||
self.logging_strategy = LoggingStrategy(self.logging_strategy)
|
||||
|
||||
if isinstance(self.evaluation_strategy, EvaluationStrategy):
|
||||
warnings.warn(
|
||||
"using `EvaluationStrategy` for `evaluation_strategy` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `IntervalStrategy` instead",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
|
||||
self.logging_strategy = IntervalStrategy(self.logging_strategy)
|
||||
self.save_strategy = IntervalStrategy(self.save_strategy)
|
||||
|
||||
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
|
||||
if self.do_eval is False and self.evaluation_strategy != IntervalStrategy.NO:
|
||||
self.do_eval = True
|
||||
if self.eval_steps is None:
|
||||
self.eval_steps = self.logging_steps
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||
details.
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
|
||||
The evaluation strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No evaluation is done during training.
|
||||
|
|
@ -102,7 +102,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||
logging_dir (:obj:`str`, `optional`):
|
||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.LoggingStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
logging_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The logging strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No logging is done during training.
|
||||
|
|
@ -113,8 +113,15 @@ class TFTrainingArguments(TrainingArguments):
|
|||
Whether to log and evaluate the first :obj:`global_step` or not.
|
||||
logging_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of update steps between two logs if :obj:`logging_strategy="steps"`.
|
||||
save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"steps"`):
|
||||
The checkpoint save strategy to adopt during training. Possible values are:
|
||||
|
||||
* :obj:`"no"`: No save is done during training.
|
||||
* :obj:`"epoch"`: Save is done at the end of each epoch.
|
||||
* :obj:`"steps"`: Save is done every :obj:`save_steps`.
|
||||
|
||||
save_steps (:obj:`int`, `optional`, defaults to 500):
|
||||
Number of updates steps before two checkpoint saves.
|
||||
Number of updates steps before two checkpoint saves if :obj:`save_strategy="steps"`.
|
||||
save_total_limit (:obj:`int`, `optional`):
|
||||
If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
|
||||
:obj:`output_dir`.
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from typing import Optional
|
|||
import IPython.display as disp
|
||||
|
||||
from ..trainer_callback import TrainerCallback
|
||||
from ..trainer_utils import EvaluationStrategy
|
||||
from ..trainer_utils import IntervalStrategy
|
||||
|
||||
|
||||
def format_time(t):
|
||||
|
|
@ -277,11 +277,11 @@ class NotebookProgressCallback(TrainerCallback):
|
|||
self._force_next_update = False
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
|
||||
self.first_column = "Epoch" if args.evaluation_strategy == IntervalStrategy.EPOCH else "Step"
|
||||
self.training_loss = 0
|
||||
self.last_log = 0
|
||||
column_names = [self.first_column] + ["Training Loss"]
|
||||
if args.evaluation_strategy != EvaluationStrategy.NO:
|
||||
if args.evaluation_strategy != IntervalStrategy.NO:
|
||||
column_names.append("Validation Loss")
|
||||
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
|
||||
|
||||
|
|
@ -306,7 +306,7 @@ class NotebookProgressCallback(TrainerCallback):
|
|||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
# Only for when there is no evaluation
|
||||
if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
|
||||
if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs:
|
||||
values = {"Training Loss": logs["loss"]}
|
||||
# First column is necessarily Step sine we're not in epoch eval strategy
|
||||
values["Step"] = state.global_step
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
|
|
@ -852,7 +852,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||
gradient_accumulation_steps=1,
|
||||
per_device_train_batch_size=16,
|
||||
load_best_model_at_end=True,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
evaluation_strategy=IntervalStrategy.EPOCH,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
metric_for_best_model="accuracy",
|
||||
)
|
||||
|
|
@ -867,7 +867,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||
num_train_epochs=20,
|
||||
gradient_accumulation_steps=1,
|
||||
per_device_train_batch_size=16,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
evaluation_strategy=IntervalStrategy.EPOCH,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
metric_for_best_model="accuracy",
|
||||
)
|
||||
|
|
@ -1013,7 +1013,7 @@ class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
|
|||
output_dir=tmp_dir,
|
||||
learning_rate=0.1,
|
||||
logging_steps=1,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
evaluation_strategy=IntervalStrategy.EPOCH,
|
||||
num_train_epochs=4,
|
||||
disable_tqdm=True,
|
||||
load_best_model_at_end=True,
|
||||
|
|
@ -1057,7 +1057,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
|
|||
output_dir=tmp_dir,
|
||||
learning_rate=0.1,
|
||||
logging_steps=1,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
evaluation_strategy=IntervalStrategy.EPOCH,
|
||||
num_train_epochs=4,
|
||||
disable_tqdm=True,
|
||||
load_best_model_at_end=True,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import unittest
|
|||
|
||||
from transformers import (
|
||||
DefaultFlowCallback,
|
||||
EvaluationStrategy,
|
||||
IntervalStrategy,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
Trainer,
|
||||
|
|
@ -129,15 +129,12 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||
expected_events += ["on_step_begin", "on_step_end"]
|
||||
if step % trainer.args.logging_steps == 0:
|
||||
expected_events.append("on_log")
|
||||
if (
|
||||
trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||
and step % trainer.args.eval_steps == 0
|
||||
):
|
||||
if trainer.args.evaluation_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
|
||||
expected_events += evaluation_events.copy()
|
||||
if step % trainer.args.save_steps == 0:
|
||||
expected_events.append("on_save")
|
||||
expected_events.append("on_epoch_end")
|
||||
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
if trainer.args.evaluation_strategy == IntervalStrategy.EPOCH:
|
||||
expected_events += evaluation_events.copy()
|
||||
expected_events += ["on_log", "on_train_end"]
|
||||
return expected_events
|
||||
|
|
|
|||
Loading…
Reference in a new issue