mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Callback to early stop the training if there is no model improvement after consecutive evaluations (#741)
* Added StopTrainingOnNoModelImprovement callback and callback_after_eval parameter in EvalCallback * Correction in EvalCallback and tests for StopTrainingOnNoModelImprovement * Update the docs related to new StopTrainingOnNoModelImprovement callback * Update doc Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
db5366fb51
commit
cdaa9ab418
4 changed files with 105 additions and 7 deletions
|
|
@ -189,7 +189,7 @@ It will save the best model if ``best_model_save_path`` folder is specified and
|
|||
|
||||
.. note::
|
||||
|
||||
You can pass a child callback via the ``callback_on_new_best`` argument. It will be triggered each time there is a new best model.
|
||||
You can pass child callbacks via ``callback_after_eval`` and ``callback_on_new_best`` arguments. ``callback_after_eval`` will be triggered after every evaluation, and ``callback_on_new_best`` will be triggered each time there is a new best model.
|
||||
|
||||
|
||||
.. warning::
|
||||
|
|
@ -333,6 +333,36 @@ and in total for ``max_episodes * n_envs`` episodes.
|
|||
# early as soon as the max number of episodes is reached
|
||||
model.learn(int(1e10), callback=callback_max_episodes)
|
||||
|
||||
.. _StopTrainingOnNoModelImprovement:
|
||||
|
||||
StopTrainingOnNoModelImprovement
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Stop the training if there is no new best model (no new best mean reward) after more than a specific number of consecutive evaluations.
|
||||
The idea is to save time in experiments when you know that the learning curves are somehow well behaved and, therefore,
|
||||
after many evaluations without improvement the learning has probably stabilized.
|
||||
It must be used with the :ref:`EvalCallback` and use the event triggered after every evaluation.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnNoModelImprovement
|
||||
|
||||
# Separate evaluation env
|
||||
eval_env = gym.make("Pendulum-v1")
|
||||
# Stop training if there is no improvement after more than 3 evaluations
|
||||
stop_train_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5, verbose=1)
|
||||
eval_callback = EvalCallback(eval_env, eval_freq=1000, callback_after_eval=stop_train_callback, verbose=1)
|
||||
|
||||
model = SAC("MlpPolicy", "Pendulum-v1", learning_rate=1e-3, verbose=1)
|
||||
# Almost infinite number of timesteps, but the training will stop early
|
||||
# as soon as the the number of consecutive evaluations without model
|
||||
# improvement is greater than 3
|
||||
model.learn(int(1e10), callback=eval_callback)
|
||||
|
||||
|
||||
.. automodule:: stable_baselines3.common.callbacks
|
||||
:members:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``StopTrainingOnNoModelImprovement`` to callback collection (@caburu)
|
||||
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
|
||||
depending on desired maximum width of output.
|
||||
- Allow PPO to turn of advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn
|
||||
|
|
@ -925,4 +926,4 @@ And all the contributors:
|
|||
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
|
||||
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
|
||||
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
|
||||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99
|
||||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu
|
||||
|
|
|
|||
|
|
@ -278,6 +278,7 @@ class EvalCallback(EventCallback):
|
|||
:param eval_env: The environment used for initialization
|
||||
:param callback_on_new_best: Callback to trigger
|
||||
when there is a new best model according to the ``mean_reward``
|
||||
:param callback_after_eval: Callback to trigger after every evaluation
|
||||
:param n_eval_episodes: The number of episodes to test the agent
|
||||
:param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
|
||||
:param log_path: Path to a folder where the evaluations (``evaluations.npz``)
|
||||
|
|
@ -296,6 +297,7 @@ class EvalCallback(EventCallback):
|
|||
self,
|
||||
eval_env: Union[gym.Env, VecEnv],
|
||||
callback_on_new_best: Optional[BaseCallback] = None,
|
||||
callback_after_eval: Optional[BaseCallback] = None,
|
||||
n_eval_episodes: int = 5,
|
||||
eval_freq: int = 10000,
|
||||
log_path: Optional[str] = None,
|
||||
|
|
@ -305,7 +307,13 @@ class EvalCallback(EventCallback):
|
|||
verbose: int = 1,
|
||||
warn: bool = True,
|
||||
):
|
||||
super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
|
||||
super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose)
|
||||
|
||||
self.callback_on_new_best = callback_on_new_best
|
||||
if self.callback_on_new_best is not None:
|
||||
# Give access to the parent
|
||||
self.callback_on_new_best.parent = self
|
||||
|
||||
self.n_eval_episodes = n_eval_episodes
|
||||
self.eval_freq = eval_freq
|
||||
self.best_mean_reward = -np.inf
|
||||
|
|
@ -342,6 +350,10 @@ class EvalCallback(EventCallback):
|
|||
if self.log_path is not None:
|
||||
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
|
||||
|
||||
# Init callback called on new best model
|
||||
if self.callback_on_new_best is not None:
|
||||
self.callback_on_new_best.init_callback(self.model)
|
||||
|
||||
def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Callback passed to the ``evaluate_policy`` function
|
||||
|
|
@ -360,7 +372,10 @@ class EvalCallback(EventCallback):
|
|||
|
||||
def _on_step(self) -> bool:
|
||||
|
||||
continue_training = True
|
||||
|
||||
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
|
||||
|
||||
# Sync training and eval env if there is VecNormalize
|
||||
if self.model.get_vec_normalize_env() is not None:
|
||||
try:
|
||||
|
|
@ -432,11 +447,15 @@ class EvalCallback(EventCallback):
|
|||
if self.best_model_save_path is not None:
|
||||
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
|
||||
self.best_mean_reward = mean_reward
|
||||
# Trigger callback if needed
|
||||
if self.callback is not None:
|
||||
return self._on_event()
|
||||
# Trigger callback on new best model, if needed
|
||||
if self.callback_on_new_best is not None:
|
||||
continue_training = self.callback_on_new_best.on_step()
|
||||
|
||||
return True
|
||||
# Trigger callback after every evaluation, if needed
|
||||
if self.callback is not None:
|
||||
continue_training = continue_training and self._on_event()
|
||||
|
||||
return continue_training
|
||||
|
||||
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
|
||||
"""
|
||||
|
|
@ -538,3 +557,46 @@ class StopTrainingOnMaxEpisodes(BaseCallback):
|
|||
f"{mean_ep_str}"
|
||||
)
|
||||
return continue_training
|
||||
|
||||
|
||||
class StopTrainingOnNoModelImprovement(BaseCallback):
|
||||
"""
|
||||
Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.
|
||||
|
||||
It is possible to define a minimum number of evaluations before start to count evaluations without improvement.
|
||||
|
||||
It must be used with the ``EvalCallback``.
|
||||
|
||||
:param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
|
||||
:param min_evals: Number of evaluations before start to count evaluations without improvements.
|
||||
:param verbose: Verbosity of the output (set to 1 for info messages)
|
||||
"""
|
||||
|
||||
def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
|
||||
super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose)
|
||||
self.max_no_improvement_evals = max_no_improvement_evals
|
||||
self.min_evals = min_evals
|
||||
self.last_best_mean_reward = -np.inf
|
||||
self.no_improvement_evals = 0
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"
|
||||
|
||||
continue_training = True
|
||||
|
||||
if self.n_calls > self.min_evals:
|
||||
if self.parent.best_mean_reward > self.last_best_mean_reward:
|
||||
self.no_improvement_evals = 0
|
||||
else:
|
||||
self.no_improvement_evals += 1
|
||||
if self.no_improvement_evals > self.max_no_improvement_evals:
|
||||
continue_training = False
|
||||
|
||||
self.last_best_mean_reward = self.parent.best_mean_reward
|
||||
|
||||
if self.verbose > 0 and not continue_training:
|
||||
print(
|
||||
f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
|
||||
)
|
||||
|
||||
return continue_training
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from stable_baselines3.common.callbacks import (
|
|||
EvalCallback,
|
||||
EveryNTimesteps,
|
||||
StopTrainingOnMaxEpisodes,
|
||||
StopTrainingOnNoModelImprovement,
|
||||
StopTrainingOnRewardThreshold,
|
||||
)
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
|
@ -35,9 +36,13 @@ def test_callbacks(tmp_path, model_class):
|
|||
# Stop training if the performance is good enough
|
||||
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
|
||||
|
||||
# Stop training if there is no model improvement after 2 evaluations
|
||||
callback_no_model_improvement = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=1, verbose=1)
|
||||
|
||||
eval_callback = EvalCallback(
|
||||
eval_env,
|
||||
callback_on_new_best=callback_on_best,
|
||||
callback_after_eval=callback_no_model_improvement,
|
||||
best_model_save_path=log_folder,
|
||||
log_path=log_folder,
|
||||
eval_freq=100,
|
||||
|
|
|
|||
Loading…
Reference in a new issue