diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 5935f50..818f282 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -209,8 +209,8 @@ These dictionaries are randomly initialized on the creation of the environment a model.learn(total_timesteps=100_000) -Using Callback: Monitoring Training ------------------------------------ +Callbacks: Monitoring Training +------------------------------ .. note:: @@ -308,6 +308,49 @@ If your callback returns False, training is aborted early. plt.show() +Callbacks: Evaluate Agent Performance +------------------------------------- +To periodically evaluate an agent's performance on a separate test environment, use ``EvalCallback``. +You can control the evaluation frequency with ``eval_freq`` to monitor your agent's progress during training. + +.. code-block:: python + + import os + import gymnasium as gym + + from stable_baselines3 import SAC + from stable_baselines3.common.callbacks import EvalCallback + from stable_baselines3.common.env_util import make_vec_env + + env_id = "Pendulum-v1" + n_training_envs = 1 + n_eval_envs = 5 + + # Create log dir where evaluation results will be saved + eval_log_dir = "./eval_logs/" + os.makedirs(eval_log_dir, exist_ok=True) + + # Initialize a vectorized training environment with default parameters + train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0) + + # Separate evaluation env, with different parameters passed via env_kwargs + # Eval environments can be vectorized to speed up evaluation. + eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0, + env_kwargs={'g':0.7}) + + # Create callback that evaluates agent for 5 episodes every 500 training environment steps. + # When using multiple training environments, agent will be evaluated every + # eval_freq calls to train_env.step(), thus it will be evaluated every + # (eval_freq * n_envs) training steps. See EvalCallback doc for more information. 
+ eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir, + log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1), + n_eval_episodes=5, deterministic=True, + render=False) + + model = SAC("MlpPolicy", train_env) + model.learn(5000, callback=eval_callback) + + Atari Games ----------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9492bd7..c08a923 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -61,6 +61,7 @@ Documentation: - Upgraded tutorials to Gymnasium API - Make it more explicit when using ``VecEnv`` vs Gym env - Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn) +- Added ``EvalCallback`` example (@sidney-tio) Release 1.8.0 (2023-04-07)