mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-01 03:45:11 +00:00
Merge branch 'master' into feat/mps-support
This commit is contained in:
commit
086f79ab53
6 changed files with 68 additions and 11 deletions
|
|
@ -209,8 +209,8 @@ These dictionaries are randomly initialized on the creation of the environment a
|
|||
model.learn(total_timesteps=100_000)
|
||||
|
||||
|
||||
Using Callback: Monitoring Training
|
||||
-----------------------------------
|
||||
Callbacks: Monitoring Training
|
||||
------------------------------
|
||||
|
||||
.. note::
|
||||
|
||||
|
|
@ -308,6 +308,49 @@ If your callback returns False, training is aborted early.
|
|||
plt.show()
|
||||
|
||||
|
||||
Callbacks: Evaluate Agent Performance
|
||||
-------------------------------------
|
||||
To periodically evaluate an agent's performance on a separate test environment, use ``EvalCallback``.
|
||||
You can control the evaluation frequency with ``eval_freq`` to monitor your agent's progress during training.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
import gymnasium as gym
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable-baselines3.common.env_util import make_vec_env
|
||||
|
||||
env_id = "Pendulum-v1"
|
||||
n_training_envs = 1
|
||||
n_eval_envs = 5
|
||||
|
||||
# Create log dir where evaluation results will be saved
|
||||
eval_log_dir = "./eval_logs/"
|
||||
os.makedirs(eval_log_dir, exist_ok=True)
|
||||
|
||||
# Initialize a vectorized training environment with default parameters
|
||||
train_env = make_vec_env(env_id, n_env=n_training_envs, seed=0)
|
||||
|
||||
# Separate evaluation env, with different parameters passed via env_kwargs
|
||||
# Eval environments can be vectorized to speed up evaluation.
|
||||
eval_env = make_vec_env(env_id, n_envs=n_eval_envs, seed=0,
|
||||
env_kwargs={'g':0.7})
|
||||
|
||||
# Create callback that evaluates agent for 5 episodes every 500 training environment steps.
|
||||
# When using multiple training environments, agent will be evaluated every
|
||||
# eval_freq calls to train_env.step(), thus it will be evaluated every
|
||||
# (eval_freq * n_envs) training steps. See EvalCallback doc for more information.
|
||||
eval_callback = EvalCallback(eval_env, best_model_save_path=eval_log_dir,
|
||||
log_path=eval_log_dir, eval_freq=max(500 // n_training_envs, 1),
|
||||
n_eval_episodes=5, deterministic=True,
|
||||
render=False)
|
||||
|
||||
model = SAC("MlpPolicy", train_env)
|
||||
model.learn(5000, callback=eval_callback)
|
||||
|
||||
|
||||
Atari Games
|
||||
-----------
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.0.0a5 (WIP)
|
||||
Release 2.0.0a6 (WIP)
|
||||
--------------------------
|
||||
|
||||
**Gymnasium support**
|
||||
|
|
@ -35,6 +35,8 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel)
|
||||
- Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto)
|
||||
- Fixed loading DQN changes ``target_update_interval`` (@tobirohrer)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -59,6 +61,7 @@ Documentation:
|
|||
- Upgraded tutorials to Gymnasium API
|
||||
- Make it more explicit when using ``VecEnv`` vs Gym env
|
||||
- Added UAV_Navigation_DRL_AirSim to the project page (@heleidsn)
|
||||
- Added ``EvalCallback`` example (@sidney-tio)
|
||||
|
||||
|
||||
Release 1.8.0 (2023-04-07)
|
||||
|
|
@ -1332,4 +1335,4 @@ And all the contributors:
|
|||
@carlosluis @arjun-kg @tlpss
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -103,7 +103,7 @@ setup(
|
|||
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
|
||||
install_requires=[
|
||||
"gymnasium==0.28.1",
|
||||
"numpy",
|
||||
"numpy>=1.20",
|
||||
"torch>=1.11",
|
||||
'typing_extensions>=4.0,<5; python_version < "3.8.0"',
|
||||
# For saving models
|
||||
|
|
|
|||
|
|
@ -151,8 +151,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
self.exploration_final_eps,
|
||||
self.exploration_fraction,
|
||||
)
|
||||
# Account for multiple environments
|
||||
# each call to step() corresponds to n_envs transitions
|
||||
|
||||
if self.n_envs > 1:
|
||||
if self.n_envs > self.target_update_interval:
|
||||
warnings.warn(
|
||||
|
|
@ -162,8 +161,6 @@ class DQN(OffPolicyAlgorithm):
|
|||
f"which corresponds to {self.n_envs} steps."
|
||||
)
|
||||
|
||||
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
self.q_net = self.policy.q_net
|
||||
self.q_net_target = self.policy.q_net_target
|
||||
|
|
@ -174,7 +171,9 @@ class DQN(OffPolicyAlgorithm):
|
|||
This method is called in ``collect_rollouts()`` after each step in the environment.
|
||||
"""
|
||||
self._n_calls += 1
|
||||
if self._n_calls % self.target_update_interval == 0:
|
||||
# Account for multiple environments
|
||||
# each call to step() corresponds to n_envs transitions
|
||||
if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
|
||||
polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
|
||||
# Copy running stats, see GH issue #996
|
||||
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.0.0a5
|
||||
2.0.0a6
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import torch as th
|
|||
|
||||
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
|
||||
from stable_baselines3.common.utils import get_device
|
||||
|
|
@ -730,3 +731,14 @@ def test_load_invalid_object(tmp_path):
|
|||
with warnings.catch_warnings(record=True) as record:
|
||||
PPO.load(path, custom_objects=dict(learning_rate=lambda _: 1.0))
|
||||
assert len(record) == 0
|
||||
|
||||
|
||||
def test_dqn_target_update_interval(tmp_path):
|
||||
# `target_update_interval` should not change when reloading the model. See GH Issue #1373.
|
||||
env = make_vec_env(env_id="CartPole-v1", n_envs=2)
|
||||
model = DQN("MlpPolicy", env, verbose=1, target_update_interval=100)
|
||||
model.save(tmp_path / "dqn_cartpole")
|
||||
del model
|
||||
model = DQN.load(tmp_path / "dqn_cartpole")
|
||||
os.remove(tmp_path / "dqn_cartpole.zip")
|
||||
assert model.target_update_interval == 100
|
||||
|
|
|
|||
Loading…
Reference in a new issue