Callback to early stop the training if there is no model improvement after consecutive evaluations (#741)

* Added StopTrainingOnNoModelImprovement callback and callback_after_eval parameter in EvalCallback

* Correction in EvalCallback and tests for StopTrainingOnNoModelImprovement

* Update the docs related to new StopTrainingOnNoModelImprovement callback

* Update doc

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
Julio César Alves 2022-02-25 07:56:47 -03:00 committed by GitHub
parent db5366fb51
commit cdaa9ab418
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 105 additions and 7 deletions

View file

@ -189,7 +189,7 @@ It will save the best model if ``best_model_save_path`` folder is specified and
.. note::
You can pass a child callback via the ``callback_on_new_best`` argument. It will be triggered each time there is a new best model.
You can pass child callbacks via ``callback_after_eval`` and ``callback_on_new_best`` arguments. ``callback_after_eval`` will be triggered after every evaluation, and ``callback_on_new_best`` will be triggered each time there is a new best model.
.. warning::
@ -333,6 +333,36 @@ and in total for ``max_episodes * n_envs`` episodes.
# early as soon as the max number of episodes is reached
model.learn(int(1e10), callback=callback_max_episodes)
.. _StopTrainingOnNoModelImprovement:
StopTrainingOnNoModelImprovement
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations.
The idea is to save time in experiments when you know that the learning curves are somehow well behaved and, therefore,
after many evaluations without improvement the learning has probably stabilized.
It must be used with the :ref:`EvalCallback`, passed through its ``callback_after_eval`` argument so that it is triggered after every evaluation.
.. code-block:: python
import gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
# Separate evaluation env
eval_env = gym.make("Pendulum-v1")
# Stop training if there is no improvement after more than 3 evaluations
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1)
# Almost infinite number of timesteps, but the training will stop early
# as soon as the number of consecutive evaluations without model
# improvement is greater than 3
model.learn(int(1e10), callback=eval_callback)
.. automodule:: stable_baselines3.common.callbacks
:members:

View file

@ -14,6 +14,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``StopTrainingOnNoModelImprovement`` to callback collection (@caburu)
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
depending on desired maximum width of output.
- Allow PPO to turn off advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn
@ -925,4 +926,4 @@ And all the contributors:
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu

View file

@ -278,6 +278,7 @@ class EvalCallback(EventCallback):
:param eval_env: The environment used for initialization
:param callback_on_new_best: Callback to trigger
when there is a new best model according to the ``mean_reward``
:param callback_after_eval: Callback to trigger after every evaluation
:param n_eval_episodes: The number of episodes to test the agent
:param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
:param log_path: Path to a folder where the evaluations (``evaluations.npz``)
@ -296,6 +297,7 @@ class EvalCallback(EventCallback):
self,
eval_env: Union[gym.Env, VecEnv],
callback_on_new_best: Optional[BaseCallback] = None,
callback_after_eval: Optional[BaseCallback] = None,
n_eval_episodes: int = 5,
eval_freq: int = 10000,
log_path: Optional[str] = None,
@ -305,7 +307,13 @@ class EvalCallback(EventCallback):
verbose: int = 1,
warn: bool = True,
):
super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose)
self.callback_on_new_best = callback_on_new_best
if self.callback_on_new_best is not None:
# Give access to the parent
self.callback_on_new_best.parent = self
self.n_eval_episodes = n_eval_episodes
self.eval_freq = eval_freq
self.best_mean_reward = -np.inf
@ -342,6 +350,10 @@ class EvalCallback(EventCallback):
if self.log_path is not None:
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
# Init callback called on new best model
if self.callback_on_new_best is not None:
self.callback_on_new_best.init_callback(self.model)
def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
"""
Callback passed to the ``evaluate_policy`` function
@ -360,7 +372,10 @@ class EvalCallback(EventCallback):
def _on_step(self) -> bool:
continue_training = True
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize
if self.model.get_vec_normalize_env() is not None:
try:
@ -432,11 +447,15 @@ class EvalCallback(EventCallback):
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
# Trigger callback if needed
if self.callback is not None:
return self._on_event()
# Trigger callback on new best model, if needed
if self.callback_on_new_best is not None:
continue_training = self.callback_on_new_best.on_step()
return True
# Trigger callback after every evaluation, if needed
if self.callback is not None:
continue_training = continue_training and self._on_event()
return continue_training
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
@ -538,3 +557,46 @@ class StopTrainingOnMaxEpisodes(BaseCallback):
f"{mean_ep_str}"
)
return continue_training
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early when the parent callback has not reported a new best
    mean reward for more than ``max_no_improvement_evals`` consecutive evaluations.

    The first ``min_evals`` calls of this callback are never counted as
    "evaluations without improvement", which gives the agent a warm-up period.

    It must be used with the ``EvalCallback`` (it reads ``parent.best_mean_reward``).

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
        Training stops once this number is exceeded.
    :param min_evals: Number of evaluations to wait before starting to count evaluations without improvement.
    :param verbose: Verbosity of the output (set to 1 for info messages)
    """

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super().__init__(verbose=verbose)
        # Configuration
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        # Running state: best reward seen at the previous call and the current
        # streak of evaluations that did not improve on it
        self.no_improvement_evals = 0
        self.last_best_mean_reward = -np.inf

    def _on_step(self) -> bool:
        """
        Update the no-improvement streak from the parent's best mean reward.

        :return: ``False`` if the training must be stopped, ``True`` otherwise.
        """
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"

        parent_best = self.parent.best_mean_reward
        should_stop = False

        # Only start counting once the warm-up period (``min_evals`` calls) is over
        if self.n_calls > self.min_evals:
            if parent_best > self.last_best_mean_reward:
                # A new best model appeared since the previous call: reset the streak
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                should_stop = self.no_improvement_evals > self.max_no_improvement_evals

        # Always remember the latest best reward, even during the warm-up period
        self.last_best_mean_reward = parent_best

        if should_stop and self.verbose > 0:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return not should_stop

View file

@ -12,6 +12,7 @@ from stable_baselines3.common.callbacks import (
EvalCallback,
EveryNTimesteps,
StopTrainingOnMaxEpisodes,
StopTrainingOnNoModelImprovement,
StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.env_util import make_vec_env
@ -35,9 +36,13 @@ def test_callbacks(tmp_path, model_class):
# Stop training if the performance is good enough
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
# Stop training if there is no model improvement after more than 2 consecutive evaluations
callback_no_model_improvement = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=1, verbose=1)
eval_callback = EvalCallback(
eval_env,
callback_on_new_best=callback_on_best,
callback_after_eval=callback_no_model_improvement,
best_model_save_path=log_folder,
log_path=log_folder,
eval_freq=100,