From 6cbb2c9303037395a81f3e5ac1d344d77b5b51b1 Mon Sep 17 00:00:00 2001 From: Tobias Rohrer Date: Thu, 27 Apr 2023 18:35:33 +0200 Subject: [PATCH 1/3] Fix DQN target update interval for multi-env (#1463) * Calculating target update interval per environment in `_on_step()`. See GitHub issue #1373 * Added changelog entry and changed test comment * Added requested changes from code review * Update version --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- stable_baselines3/dqn/dqn.py | 9 ++++----- stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 12 ++++++++++++ 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6fc6dfa..cfe2819 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a5 (WIP) +Release 2.0.0a6 (WIP) -------------------------- **Gymnasium support** @@ -35,6 +35,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel) +- Fixed loading DQN changes ``target_update_interval`` (@tobirohrer) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index b85a30f..bd3ad4b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -151,8 +151,7 @@ class DQN(OffPolicyAlgorithm): self.exploration_final_eps, self.exploration_fraction, ) - # Account for multiple environments - # each call to step() corresponds to n_envs transitions + if self.n_envs > 1: if self.n_envs > self.target_update_interval: warnings.warn( @@ -162,8 +161,6 @@ class DQN(OffPolicyAlgorithm): f"which corresponds to {self.n_envs} steps." 
) - self.target_update_interval = max(self.target_update_interval // self.n_envs, 1) - def _create_aliases(self) -> None: self.q_net = self.policy.q_net self.q_net_target = self.policy.q_net_target @@ -174,7 +171,9 @@ class DQN(OffPolicyAlgorithm): This method is called in ``collect_rollouts()`` after each step in the environment. """ self._n_calls += 1 - if self._n_calls % self.target_update_interval == 0: + # Account for multiple environments + # each call to step() corresponds to n_envs transitions + if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0: polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau) # Copy running stats, see GH issue #996 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 0691e44..758664d 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a5 +2.0.0a6 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 2f227ad..9d0aa44 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -15,6 +15,7 @@ import torch as th from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl from stable_baselines3.common.utils import get_device @@ -730,3 +731,14 @@ def test_load_invalid_object(tmp_path): with warnings.catch_warnings(record=True) as record: PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0)) assert len(record) == 0 + + +def test_dqn_target_update_interval(tmp_path): + # `target_update_interval` should not change when reloading the model. See GH Issue #1373. 
+ env = make_vec_env(env_id="CartPole-v1", n_envs=2) + model = DQN("MlpPolicy", env, verbose=1, target_update_interval=100) + model.save(tmp_path / "dqn_cartpole") + del model + model = DQN.load(tmp_path / "dqn_cartpole") + os.remove(tmp_path / "dqn_cartpole.zip") + assert model.target_update_interval == 100 From 4f9805eeb8d4ee888b0967f05f142f72c97e1867 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=9Cbertreiber?= Date: Thu, 27 Apr 2023 19:07:53 +0200 Subject: [PATCH 2/3] Fix overly relaxed version requirement on NumPy (#1472) Since commit 489b1fda, this package has been using `numpy.typing.DTypeLike`, which was only added in [NumPy 1.20][1]. [1]: https://numpy.org/doc/stable/release/1.20.0-notes.html#numpy-is-now-typed Co-authored-by: troiganto Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cfe2819..9492bd7 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -35,6 +35,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel) +- Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto) - Fixed loading DQN changes ``target_update_interval`` (@tobirohrer) Deprecations: @@ -1331,4 +1332,4 @@ And all the contributors: @carlosluis @arjun-kg @tlpss @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto diff --git a/setup.py b/setup.py index a7a3fcc..d72e77d 100644 --- a/setup.py +++ b/setup.py @@ -103,7 +103,7 @@ setup( 
package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ "gymnasium==0.28.1", - "numpy", + "numpy>=1.20", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', # For saving models From d6ddee9366fe8fc2c8fc5997371301ff85aac36c Mon Sep 17 00:00:00 2001 From: Sidney Tio <35787241+sidney-tio@users.noreply.github.com> Date: Wed, 3 May 2023 00:02:36 +0800 Subject: [PATCH 3/3] Add evalcallback example (#1468) * Moved 'Monitoring Training' to subsubsection of 'Using callbacks' * Added EvalCallback example * Updated Changelogs * Edited the language * Moved subsection headers up one level * added make_vec_env into Evalcallback example * Added parameters to the top for readability * Added note on multiple training environments * Added more clarity to eval_freq note * Apply suggestions from code review --------- Co-authored-by: Antonin RAFFIN --- docs/guide/examples.rst | 47 +++++++++++++++++++++++++++++++++++++++-- docs/misc/changelog.rst | 1 + 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 5935f50..818f282 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -209,8 +209,8 @@ These dictionaries are randomly initialized on the creation of the environment a model.learn(total_timesteps=100_000) -Using Callback: Monitoring Training ------------------------------------ +Callbacks: Monitoring Training +------------------------------ .. note:: @@ -308,6 +308,49 @@ If your callback returns False, training is aborted early. plt.show() +Callbacks: Evaluate Agent Performance +------------------------------------- +To periodically evaluate an agent's performance on a separate test environment, use ``EvalCallback``. +You can control the evaluation frequency with ``eval_freq`` to monitor your agent's progress during training. + +.. 
code-block:: python + + import os + import gymnasium as gym + + from stable_baselines3 import SAC + from stable_baselines3.common.callbacks import EvalCallback + from stable_baselines3.common.env_util import make_vec_env + + env_id = "Pendulum-v1" + n_training_envs = 1 + n_eval_envs = 5 + + # Create log dir where evaluation results will be saved + eval_log_dir = "./eval_logs/" + os.makedirs(eval_log_dir, exist_ok=True) + + # Initialize a vectorized training environment with default parameters + train_env = make_vec_env(env_id, n_envs=n_training_envs, seed=0) + + # Separate evaluation env, with different parameters passed via env_kwargs + # Eval environments can be vectorized to speed up evaluation. + eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0, + env_kwargs={'g':0.7}) + + # Create callback that evaluates agent for 5 episodes every 500 training environment steps. + # When using multiple training environments, agent will be evaluated every + # eval_freq calls to train_env.step(), thus it will be evaluated every + # (eval_freq * n_envs) training steps. See EvalCallback doc for more information. + eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir, + log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1), + n_eval_episodes=5, deterministic=True, + render=False) + + model = SAC("MlpPolicy", train_env) + model.learn(5000, callback=eval_callback) + + Atari Games ----------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9492bd7..c08a923 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -61,6 +61,7 @@ Documentation: - Upgraded tutorials to Gymnasium API - Make it more explicit when using ``VecEnv`` vs Gym env - Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn) +- Added ``EvalCallback`` example (@sidney-tio) Release 1.8.0 (2023-04-07)