diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 5935f50..818f282 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -209,8 +209,8 @@ These dictionaries are randomly initialized on the creation of the environment a model.learn(total_timesteps=100_000) -Using Callback: Monitoring Training ------------------------------------ +Callbacks: Monitoring Training +------------------------------ .. note:: @@ -308,6 +308,49 @@ If your callback returns False, training is aborted early. plt.show() +Callbacks: Evaluate Agent Performance +------------------------------------- +To periodically evaluate an agent's performance on a separate test environment, use ``EvalCallback``. +You can control the evaluation frequency with ``eval_freq`` to monitor your agent's progress during training. + +.. code-block:: python + + import os + import gymnasium as gym + + from stable_baselines3 import SAC + from stable_baselines3.common.callbacks import EvalCallback + from stable-baselines3.common.env_util import make_vec_env + + env_id = "Pendulum-v1" + n_training_envs = 1 + n_eval_envs = 5 + + # Create log dir where evaluation results will be saved + eval_log_dir = "./eval_logs/" + os.makedirs(eval_log_dir, exist_ok=True) + + # Initialize a vectorized training environment with default parameters + train_env = make_vec_env(env_id, n_env=n_training_envs, seed=0) + + # Separate evaluation env, with different parameters passed via env_kwargs + # Eval environments can be vectorized to speed up evaluation. + eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0, + env_kwargs={'g':0.7}) + + # Create callback that evaluates agent for 5 episodes every 500 training environment steps. + # When using multiple training environments, agent will be evaluated every + # eval_freq calls to train_env.step(), thus it will be evaluated every + # (eval_freq * n_envs) training steps. See EvalCallback doc for more information. + eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir, + log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1), + n_eval_episodes=5, deterministic=True, + render=False) + + model = SAC("MlpPolicy", train_env) + model.learn(5000, callback=eval_callback) + + Atari Games ----------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a47118d..b23f764 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a5 (WIP) +Release 2.0.0a6 (WIP) -------------------------- **Gymnasium support** @@ -35,6 +35,8 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel) +- Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto) +- Fixed loading DQN changes ``target_update_interval`` (@tobirohrer) Deprecations: ^^^^^^^^^^^^^ @@ -59,6 +61,7 @@ Documentation: - Upgraded tutorials to Gymnasium API - Make it more explicit when using ``VecEnv`` vs Gym env - Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn) +- Added ``EvalCallback`` example (@sidney-tio) Release 1.8.0 (2023-04-07) @@ -1332,4 +1335,4 @@ And all the contributors: @carlosluis @arjun-kg @tlpss @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto diff --git a/setup.py b/setup.py index a7a3fcc..d72e77d 100644 --- a/setup.py +++ b/setup.py @@ -103,7 +103,7 @@ setup( package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ "gymnasium==0.28.1", - "numpy", + "numpy>=1.20", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', # For saving models diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index b85a30f..bd3ad4b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -151,8 +151,7 @@ class DQN(OffPolicyAlgorithm): self.exploration_final_eps, self.exploration_fraction, ) - # Account for multiple environments - # each call to step() corresponds to n_envs transitions + if self.n_envs > 1: if self.n_envs > self.target_update_interval: warnings.warn( @@ -162,8 +161,6 @@ class DQN(OffPolicyAlgorithm): f"which corresponds to {self.n_envs} steps." ) - self.target_update_interval = max(self.target_update_interval // self.n_envs, 1) - def _create_aliases(self) -> None: self.q_net = self.policy.q_net self.q_net_target = self.policy.q_net_target @@ -174,7 +171,9 @@ class DQN(OffPolicyAlgorithm): This method is called in ``collect_rollouts()`` after each step in the environment. """ self._n_calls += 1 - if self._n_calls % self.target_update_interval == 0: + # Account for multiple environments + # each call to step() corresponds to n_envs transitions + if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0: polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau) # Copy running stats, see GH issue #996 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 0691e44..758664d 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a5 +2.0.0a6 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 2f227ad..9d0aa44 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -15,6 +15,7 @@ import torch as th from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl from stable_baselines3.common.utils import get_device @@ -730,3 +731,14 @@ def test_load_invalid_object(tmp_path): with warnings.catch_warnings(record=True) as record: PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0)) assert len(record) == 0 + + +def test_dqn_target_update_interval(tmp_path): + # `target_update_interval` should not change when reloading the model. See GH Issue #1373. + env = make_vec_env(env_id="CartPole-v1", n_envs=2) + model = DQN("MlpPolicy", env, verbose=1, target_update_interval=100) + model.save(tmp_path / "dqn_cartpole") + del model + model = DQN.load(tmp_path / "dqn_cartpole") + os.remove(tmp_path / "dqn_cartpole.zip") + assert model.target_update_interval == 100