From cdaa9ab418aec18f41c7e8e12e0ad28f238553eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julio=20C=C3=A9sar=20Alves?= Date: Fri, 25 Feb 2022 07:56:47 -0300 Subject: [PATCH] Callback to early stop the training if there is no model improvement after consecutive evaluations (#741) * Added StopTrainingOnNoModelImprovement callback and callback_after_eval parameter in EvalCallback * Correction in EvalCallback and tests for StopTrainingOnNoModelImprovement * Update the docs related to new StopTrainingOnNoModelImprovement callback * Update doc Co-authored-by: Antonin RAFFIN Co-authored-by: Antonin Raffin --- docs/guide/callbacks.rst | 32 +++++++++++- docs/misc/changelog.rst | 3 +- stable_baselines3/common/callbacks.py | 72 +++++++++++++++++++++++++-- tests/test_callbacks.py | 5 ++ 4 files changed, 105 insertions(+), 7 deletions(-) diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 19bccb2..6c7f4eb 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -189,7 +189,7 @@ It will save the best model if ``best_model_save_path`` folder is specified and .. note:: - You can pass a child callback via the ``callback_on_new_best`` argument. It will be triggered each time there is a new best model. + You can pass child callbacks via ``callback_after_eval`` and ``callback_on_new_best`` arguments. ``callback_after_eval`` will be triggered after every evaluation, and ``callback_on_new_best`` will be triggered each time there is a new best model. .. warning:: @@ -333,6 +333,36 @@ and in total for ``max_episodes * n_envs`` episodes. # early as soon as the max number of episodes is reached model.learn(int(1e10), callback=callback_max_episodes) +.. _StopTrainingOnNoModelImprovement: + +StopTrainingOnNoModelImprovement +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations. +The idea is to save time in experiments when you know that the learning curves are somehow well behaved and, therefore, +after many evaluations without improvement the learning has probably stabilized. +It must be used with the :ref:`EvalCallback` and use the event triggered after every evaluation. + + +.. code-block:: python + + import gym + + from stable_baselines3 import SAC + from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement + + # Separate evaluation env + eval_env = gym.make("Pendulum-v1") + # Stop training if there is no improvement after more than 3 evaluations + stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1) + eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1) + + model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1) + # Almost infinite number of timesteps, but the training will stop early + # as soon as the the number of consecutive evaluations without model + # improvement is greater than 3 + model.learn(int(1e10), callback=eval_callback) + .. automodule:: stable_baselines3.common.callbacks :members: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index ab0255a..c5be840 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added ``StopTrainingOnNoModelImprovement`` to callback collection (@caburu) - Makes the length of keys and values in ``HumanOutputFormat`` configurable, depending on desired maximum width of output. - Allow PPO to turn of advantage normalization (see `PR #763 `_) @vwxyzjn @@ -925,4 +926,4 @@ And all the contributors: @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP -@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 +@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index cba6cb8..27ce5e6 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -278,6 +278,7 @@ class EvalCallback(EventCallback): :param eval_env: The environment used for initialization :param callback_on_new_best: Callback to trigger when there is a new best model according to the ``mean_reward`` + :param callback_after_eval: Callback to trigger after every evaluation :param n_eval_episodes: The number of episodes to test the agent :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback. :param log_path: Path to a folder where the evaluations (``evaluations.npz``) @@ -296,6 +297,7 @@ class EvalCallback(EventCallback): self, eval_env: Union[gym.Env, VecEnv], callback_on_new_best: Optional[BaseCallback] = None, + callback_after_eval: Optional[BaseCallback] = None, n_eval_episodes: int = 5, eval_freq: int = 10000, log_path: Optional[str] = None, @@ -305,7 +307,13 @@ class EvalCallback(EventCallback): verbose: int = 1, warn: bool = True, ): - super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose) + super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose) + + self.callback_on_new_best = callback_on_new_best + if self.callback_on_new_best is not None: + # Give access to the parent + self.callback_on_new_best.parent = self + self.n_eval_episodes = n_eval_episodes self.eval_freq = eval_freq self.best_mean_reward = -np.inf @@ -342,6 +350,10 @@ class EvalCallback(EventCallback): if self.log_path is not None: os.makedirs(os.path.dirname(self.log_path), exist_ok=True) + # Init callback called on new best model + if self.callback_on_new_best is not None: + self.callback_on_new_best.init_callback(self.model) + def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None: """ Callback passed to the ``evaluate_policy`` function @@ -360,7 +372,10 @@ class EvalCallback(EventCallback): def _on_step(self) -> bool: + continue_training = True + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + # Sync training and eval env if there is VecNormalize if self.model.get_vec_normalize_env() is not None: try: @@ -432,11 +447,15 @@ class EvalCallback(EventCallback): if self.best_model_save_path is not None: self.model.save(os.path.join(self.best_model_save_path, "best_model")) self.best_mean_reward = mean_reward - # Trigger callback if needed - if self.callback is not None: - return self._on_event() + # Trigger callback on new best model, if needed + if self.callback_on_new_best is not None: + continue_training = self.callback_on_new_best.on_step() - return True + # Trigger callback after every evaluation, if needed + if self.callback is not None: + continue_training = continue_training and self._on_event() + + return continue_training def update_child_locals(self, locals_: Dict[str, Any]) -> None: """ @@ -538,3 +557,46 @@ class StopTrainingOnMaxEpisodes(BaseCallback): f"{mean_ep_str}" ) return continue_training + + +class StopTrainingOnNoModelImprovement(BaseCallback): + """ + Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations. + + It is possible to define a minimum number of evaluations before start to count evaluations without improvement. + + It must be used with the ``EvalCallback``. + + :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model. + :param min_evals: Number of evaluations before start to count evaluations without improvements. + :param verbose: Verbosity of the output (set to 1 for info messages) + """ + + def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): + super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose) + self.max_no_improvement_evals = max_no_improvement_evals + self.min_evals = min_evals + self.last_best_mean_reward = -np.inf + self.no_improvement_evals = 0 + + def _on_step(self) -> bool: + assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``" + + continue_training = True + + if self.n_calls > self.min_evals: + if self.parent.best_mean_reward > self.last_best_mean_reward: + self.no_improvement_evals = 0 + else: + self.no_improvement_evals += 1 + if self.no_improvement_evals > self.max_no_improvement_evals: + continue_training = False + + self.last_best_mean_reward = self.parent.best_mean_reward + + if self.verbose > 0 and not continue_training: + print( + f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations" + ) + + return continue_training diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index e1f6d38..6576f7d 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -12,6 +12,7 @@ from stable_baselines3.common.callbacks import ( EvalCallback, EveryNTimesteps, StopTrainingOnMaxEpisodes, + StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold, ) from stable_baselines3.common.env_util import make_vec_env @@ -35,9 +36,13 @@ def test_callbacks(tmp_path, model_class): # Stop training if the performance is good enough callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1) + # Stop training if there is no model improvement after 2 evaluations + callback_no_model_improvement = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=1, verbose=1) + eval_callback = EvalCallback( eval_env, callback_on_new_best=callback_on_best, + callback_after_eval=callback_no_model_improvement, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100,