Merge branch 'master' into feat/mps-support

This commit is contained in:
Antonin RAFFIN 2023-05-03 10:31:22 +02:00 committed by GitHub
commit 086f79ab53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 68 additions and 11 deletions

View file

@ -209,8 +209,8 @@ These dictionaries are randomly initialized on the creation of the environment a
model.learn(total_timesteps=100_000)
Using Callback: Monitoring Training
-----------------------------------
Callbacks: Monitoring Training
------------------------------
.. note::
@ -308,6 +308,49 @@ If your callback returns False, training is aborted early.
plt.show()
Callbacks: Evaluate Agent Performance
-------------------------------------
To periodically evaluate an agent's performance on a separate test environment, use ``EvalCallback``.
You can control the evaluation frequency with ``eval_freq`` to monitor your agent's progress during training.
.. code-block:: python
import os
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
from stable-baselines3.common.env_util import make_vec_env
env_id = "Pendulum-v1"
n_training_envs = 1
n_eval_envs = 5
# Create log dir where evaluation results will be saved
eval_log_dir = "./eval_logs/"
os.makedirs(eval_log_dir, exist_ok=True)
# Initialize a vectorized training environment with default parameters
train_env = make_vec_env(env_id, n_env=n_training_envs, seed=0)
# Separate evaluation env, with different parameters passed via env_kwargs
# Eval environments can be vectorized to speed up evaluation.
eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0,
env_kwargs={'g':0.7})
# Create callback that evaluates agent for 5 episodes every 500 training environment steps.
# When using multiple training environments, agent will be evaluated every
# eval_freq calls to train_env.step(), thus it will be evaluated every
# (eval_freq * n_envs) training steps. See EvalCallback doc for more information.
eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir,
log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1),
n_eval_episodes=5, deterministic=True,
render=False)
model = SAC("MlpPolicy", train_env)
model.learn(5000, callback=eval_callback)
Atari Games
-----------

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 2.0.0a5 (WIP)
Release 2.0.0a6 (WIP)
--------------------------
**Gymnasium support**
@ -35,6 +35,8 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel)
- Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto)
- Fixed loading DQN changes ``target_update_interval`` (@tobirohrer)
Deprecations:
^^^^^^^^^^^^^
@ -59,6 +61,7 @@ Documentation:
- Upgraded tutorials to Gymnasium API
- Make it more explicit when using ``VecEnv`` vs Gym env
- Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn)
- Added ``EvalCallback`` example (@sidney-tio)
Release 1.8.0 (2023-04-07)
@ -1332,4 +1335,4 @@ And all the contributors:
@carlosluis @arjun-kg @tlpss
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto

View file

@ -103,7 +103,7 @@ setup(
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium==0.28.1",
"numpy",
"numpy>=1.20",
"torch>=1.11",
'typing_extensions>=4.0,<5; python_version < "3.8.0"',
# For saving models

View file

@ -151,8 +151,7 @@ class DQN(OffPolicyAlgorithm):
self.exploration_final_eps,
self.exploration_fraction,
)
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self.n_envs > 1:
if self.n_envs > self.target_update_interval:
warnings.warn(
@ -162,8 +161,6 @@ class DQN(OffPolicyAlgorithm):
f"which corresponds to {self.n_envs} steps."
)
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
def _create_aliases(self) -> None:
self.q_net = self.policy.q_net
self.q_net_target = self.policy.q_net_target
@ -174,7 +171,9 @@ class DQN(OffPolicyAlgorithm):
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
# Copy running stats, see GH issue #996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

View file

@ -1 +1 @@
2.0.0a5
2.0.0a6

View file

@ -15,6 +15,7 @@ import torch as th
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
@ -730,3 +731,14 @@ def test_load_invalid_object(tmp_path):
with warnings.catch_warnings(record=True) as record:
PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
assert len(record) == 0
def test_dqn_target_update_interval(tmp_path):
# `target_update_interval` should not change when reloading the model. See GH Issue #1373.
env = make_vec_env(env_id="CartPole-v1", n_envs=2)
model = DQN("MlpPolicy", env, verbose=1, target_update_interval=100)
model.save(tmp_path / "dqn_cartpole")
del model
model = DQN.load(tmp_path / "dqn_cartpole")
os.remove(tmp_path / "dqn_cartpole.zip")
assert model.target_update_interval == 100