[integration] Update Ray Tune integration for Ray 2.7 (#26499)
* fix tune integration for ray 2.7+
* add version check for ray tune backend availability
* missing import
* pin min version instead
* address comments
* some fixes
* fix unnecessary final checkpoint
* fix lint
* dep table fix
* fix lint

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
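For context: the heart of this migration is Ray 2.7's consolidation of the session APIs. `ray.tune.report(...)` and `tune.checkpoint_dir(...)` give way to `ray.train.report(metrics, checkpoint=...)` with explicit `Checkpoint` objects. A minimal sketch of the new reporting pattern (illustrative only; `report_with_checkpoint` and `state_blob` are made-up names, not part of this diff):

import os
import tempfile

import ray.train

def report_with_checkpoint(metrics: dict, state_blob: bytes) -> None:
    # Ray >= 2.7: write checkpoint state to a local directory, wrap it in a
    # Checkpoint object, and attach it to the metrics report. Only valid
    # inside a running Tune/Train session.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with open(os.path.join(tmp_dir, "state.bin"), "wb") as f:
            f.write(state_blob)
        checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
        ray.train.report(metrics, checkpoint=checkpoint)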
This commit is contained in:
parent ffd426eef8
commit 5fa66df3f3

5 changed files with 50 additions and 52 deletions

setup.py
@@ -149,7 +149,7 @@ _deps = [
     "pytest-timeout",
     "pytest-xdist",
     "python>=3.8.0",
-    "ray[tune]",
+    "ray[tune]>=2.7.0",
     "regex!=2019.12.17",
     "requests",
     "rhoknp>=1.1.0,<1.3.1",
src/transformers/dependency_versions_table.py

@@ -55,7 +55,7 @@ deps = {
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
     "python": "python>=3.8.0",
-    "ray[tune]": "ray[tune]",
+    "ray[tune]": "ray[tune]>=2.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "rhoknp": "rhoknp>=1.1.0,<1.3.1",
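Note: src/transformers/dependency_versions_table.py is auto-generated from the `_deps` list in setup.py (the repo regenerates it with `make deps_table_update`), which is why the `ray[tune]>=2.7.0` pin appears in both files. The "dep table fix" entry in the commit message refers to this regeneration.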
src/transformers/hyperparameter_search.py

@@ -15,7 +15,7 @@
 from .integrations import (
     is_optuna_available,
-    is_ray_available,
+    is_ray_tune_available,
     is_sigopt_available,
     is_wandb_available,
     run_hp_search_optuna,
     run_hp_search_ray,
@@ -81,7 +81,7 @@ class RayTuneBackend(HyperParamSearchBackendBase):

     @staticmethod
     def is_available():
-        return is_ray_available()
+        return is_ray_tune_available()

     def run(self, trainer, n_trials: int, direction: str, **kwargs):
         return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
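The hunk above hinges on a distinction the commit message calls out ("add version check for ray tune backend availability"): having `ray` installed does not imply the Tune extra is usable. A sketch that mirrors the shape of the transformers helpers, not their exact implementation:

import importlib.util

def is_ray_available() -> bool:
    # Core ray is installed if its top-level package can be found.
    return importlib.util.find_spec("ray") is not None

def is_ray_tune_available() -> bool:
    # ray[tune] installs the ray.tune subpackage alongside core ray.
    return is_ray_available() and importlib.util.find_spec("ray.tune") is not None

The runtime version gate was ultimately dropped in favor of the `ray[tune]>=2.7.0` pin in setup.py ("pin min version instead").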
src/transformers/integrations/integration_utils.py

@@ -236,8 +236,9 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     import ray
+    import ray.train

-    def _objective(trial, local_trainer, checkpoint_dir=None):
+    def _objective(trial: dict, local_trainer):
         try:
             from transformers.utils.notebook import NotebookProgressCallback

@@ -246,19 +247,34 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
         except ModuleNotFoundError:
             pass

-        checkpoint = None
-        if checkpoint_dir:
-            for subdir in os.listdir(checkpoint_dir):
-                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
-                    checkpoint = os.path.join(checkpoint_dir, subdir)
         local_trainer.objective = None
-        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
+
+        checkpoint = ray.train.get_checkpoint()
+        if checkpoint:
+            # Upon trial resume, the local_trainer's objective gets reset to None.
+            # If `local_trainer.train` is a noop (training has already reached
+            # the target number of epochs/steps), then this would
+            # trigger an unnecessary extra checkpoint at the end of training.
+            # -> Set the objective to a dummy value upon resume as a workaround.
+            local_trainer.objective = "objective"
+
+            with checkpoint.as_directory() as checkpoint_dir:
+                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
+                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
+        else:
+            local_trainer.train(trial=trial)

         # If there hasn't been any evaluation during the training loop.
         if getattr(local_trainer, "objective", None) is None:
             metrics = local_trainer.evaluate()
             local_trainer.objective = local_trainer.compute_objective(metrics)
-            local_trainer._tune_save_checkpoint()
-            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
+            metrics.update({"objective": local_trainer.objective, "done": True})
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                ray.train.report(metrics, checkpoint=checkpoint)

     if not trainer._memory_tracker.skip_memory_metrics:
         from ..trainer_utils import TrainerMemoryTracker
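The resume pattern in this hunk (pull the checkpoint from the session, materialize it locally, find the HF checkpoint folder inside it) is the Ray 2.7 idiom. A self-contained Tune trainable using the same pattern, for reference (`train_fn` and the `state.json` layout are illustrative assumptions, not transformers code):

import json
import os
import tempfile

import ray.train
from ray import tune

def train_fn(config: dict) -> None:
    start = 0
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        # as_directory() materializes the checkpoint on local disk for reading.
        with checkpoint.as_directory() as ckpt_dir:
            with open(os.path.join(ckpt_dir, "state.json")) as f:
                start = json.load(f)["step"]
    for step in range(start, 5):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with open(os.path.join(tmp_dir, "state.json"), "w") as f:
                json.dump({"step": step + 1}, f)
            ray.train.report(
                {"objective": float(step)},
                checkpoint=ray.train.Checkpoint.from_directory(tmp_dir),
            )

# tuner = tune.Tuner(train_fn, tune_config=tune.TuneConfig(num_samples=1))
# results = tuner.fit()  # uncomment to actually run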
@@ -296,28 +312,10 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
         from ray.tune import CLIReporter

         kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
-    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-        # `keep_checkpoints_num=0` would disabled checkpointing
-        trainer.use_tune_checkpoints = True
-        if kwargs["keep_checkpoints_num"] > 1:
-            logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
-                "Checkpoints are usually huge, "
-                "consider setting `keep_checkpoints_num=1`."
-            )

     if "scheduler" in kwargs:
         from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

-        # Check if checkpointing is enabled for PopulationBasedTraining
-        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-            if not trainer.use_tune_checkpoints:
-                logger.warning(
-                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                    "This means your trials will train from scratch everytime they are exploiting "
-                    "new configurations. Consider enabling checkpointing by passing "
-                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                )
-
         # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
         if isinstance(
             kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
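The removed `keep_checkpoints_num` handling has no direct replacement in this function: with the new flow every trial checkpoints unconditionally, and retention is configured on the Ray side. A hedged sketch of the Ray 2.7 knob (whether and how `Trainer.hyperparameter_search` forwards this is not shown in this diff):

# Ray 2.7-style checkpoint retention (illustrative; not wired up by this commit).
from ray import train

checkpoint_config = train.CheckpointConfig(
    num_to_keep=1,                           # analogous to keep_checkpoints_num=1
    checkpoint_score_attribute="objective",  # rank kept checkpoints by the reported objective
)
run_config = train.RunConfig(checkpoint_config=checkpoint_config)
# run_config would then be passed to a tune.Tuner(...) alongside the trainable.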
src/transformers/trainer.py

@@ -28,6 +28,7 @@ import random
 import re
 import shutil
 import sys
+import tempfile
 import time
 import warnings
 from collections.abc import Mapping
@@ -595,7 +596,6 @@ class Trainer:
         # returned to 0 every time flos need to be logged
         self.current_flos = 0
         self.hp_search_backend = None
-        self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.can_return_loss = can_return_loss(self.model.__class__)
@@ -1201,7 +1201,8 @@ class Trainer:
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
             return
-        self.objective = self.compute_objective(metrics.copy())
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             import optuna

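Note on the hunk above: the RAY branch that follows now writes into `metrics` (it sets the "objective" key before reporting), so `_report_to_hp_search` takes its own copy up front instead of only copying for `compute_objective` and then mutating the caller's dict later. A minimal illustration of the aliasing this avoids (standalone sketch, not transformers code):

metrics = {"eval_loss": 0.25}
reported = metrics             # aliasing: both names point at the same dict
reported["objective"] = 0.25
assert "objective" in metrics  # caller's dict changed as a side effect

metrics = {"eval_loss": 0.25}
reported = metrics.copy()      # what the patch does instead
reported["objective"] = 0.25
assert "objective" not in metrics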
@@ -1211,24 +1212,23 @@ class Trainer:
                 self.callback_handler.on_train_end(self.args, self.state, self.control)
                 raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
+            import ray.train

-            if self.control.should_save:
-                self._tune_save_checkpoint()
-            tune.report(objective=self.objective, **metrics)
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)

-    def _tune_save_checkpoint(self):
-        from ray import tune
-
-        if not self.use_tune_checkpoints:
-            return
-        with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
-            output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-            self.save_model(output_dir, _internal_call=True)
-            if self.args.should_save:
-                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
-                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+        self.save_model(output_dir, _internal_call=True)
+        if self.args.should_save:
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

     def call_model_init(self, trial=None):
         model_init_argcount = number_of_arguments(self.model_init)
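Taken together, the Trainer-side changes mean checkpoints now live in a temporary directory owned by the reporting call, rather than a Tune-managed `tune.checkpoint_dir`. For orientation, an end-to-end usage sketch of the code path this commit fixes (the `hp_space` values and the `model_init`/`training_args` wiring are illustrative assumptions, not from this diff):

def ray_hp_space(trial):
    from ray import tune

    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "per_device_train_batch_size": tune.choice([16, 32]),
    }

# trainer = Trainer(model_init=model_init, args=training_args, ...)
# best_run = trainer.hyperparameter_search(
#     hp_space=ray_hp_space,
#     backend="ray",
#     n_trials=4,
#     direction="minimize",
# )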
@@ -2004,9 +2004,9 @@ class Trainer:
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
             run_id = trial.number
         elif self.hp_search_backend == HPSearchBackend.RAY:
-            from ray import tune
+            import ray.train

-            run_id = tune.get_trial_id()
+            run_id = ray.train.get_context().get_trial_id()
         elif self.hp_search_backend == HPSearchBackend.SIGOPT:
             run_id = trial.id
         elif self.hp_search_backend == HPSearchBackend.WANDB:
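Note on this last hunk: `tune.get_trial_id()` is deprecated in Ray 2.7, and `ray.train.get_context().get_trial_id()` is the supported accessor. Like `ray.train.report`, it only returns a meaningful value from inside a running trial.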