Update doc: SB3-Contrib (#267)

* Fix big when saving/loading q-net alone

* Rename variables to match SB3-contrib

* Update docker image

* Set min version for tensorboard

* Add SB3-Contrib to doc

* Update DQN

* Apply suggestions from code review

Co-authored-by: Adam Gleave <adam@gleave.me>

* Update wording

Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
Antonin RAFFIN 2020-12-21 16:17:24 +01:00 committed by GitHub
parent b8c72a5348
commit 944dfdafe4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 234 additions and 41 deletions

View file

@ -1,4 +1,4 @@
image: stablebaselines/stable-baselines3-cpu:0.9.0a2
image: stablebaselines/stable-baselines3-cpu:0.11.0a4
type-check:
script:

View file

@ -8,7 +8,7 @@
# Stable Baselines3
Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
You can read a detailed presentation of Stable Baselines in the [Medium article](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-reinforcement-learning-made-easy-df87c4b2fc82).
@ -50,7 +50,6 @@ A migration guide from SB2 to SB3 can be found in the [documentation](https://st
Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)
## RL Baselines3 Zoo: A Collection of Trained RL Agents
[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3.
@ -68,6 +67,15 @@ Github repo: https://github.com/DLR-RM/rl-baselines3-zoo
Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html
## SB3-Contrib: Experimental RL Features
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC) or Quantile Regression DQN (QR-DQN).
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
## Installation
**Note:** Stable-Baselines3 supports PyTorch 1.4+.

View file

@ -31,6 +31,10 @@ Actions ``gym.spaces``:
- ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination.
.. note::
More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.
.. note::
Some logging values (like ``ep_rew_mean``, ``ep_len_mean``) are only available when using a ``Monitor`` wrapper

View file

@ -87,8 +87,6 @@ Looking at the training curve (episode reward function of the timesteps) is a go
We suggest you reading `Deep Reinforcement Learning that Matters <https://arxiv.org/abs/1709.06560>`_ for a good discussion about RL evaluation.
You can also take a look at this `blog post <https://openlab-flowers.inria.fr/t/how-many-random-seeds-should-i-use-statistical-power-analysis-in-deep-reinforcement-learning-experiments/457>`_
@ -122,6 +120,7 @@ Discrete Actions - Single Process
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
DQN with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
We notably provide QR-DQN in our :ref:`contrib repo <sb3_contrib>`.
DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).
Discrete Actions - Multiprocessed
@ -136,7 +135,7 @@ Continuous Actions
Continuous Actions - Single Process
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Current State Of The Art (SOTA) algorithms are ``SAC`` and ``TD3``.
Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo <sb3_contrib>`).
Please use the hyperparameters in the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ for best results.
@ -156,7 +155,7 @@ Goal Environment
-----------------
If your environment follows the ``GoalEnv`` interface (cf :ref:`HER <her>`), then you should use
HER + (SAC/TD3/DDPG/DQN) depending on the action space.
HER + (SAC/TD3/DDPG/DQN/TQC) depending on the action space.
.. note::

View file

@ -0,0 +1,97 @@
.. _sb3_contrib:
==================
SB3 Contrib
==================
We implement experimental features in a separate contrib repository:
`SB3-Contrib`_
This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still
providing the latest features, like Truncated Quantile Critics (TQC) or
Quantile Regression DQN (QR-DQN).
Why create this repository?
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Over the span of stable-baselines and stable-baselines3, the community
has been eager to contribute in form of better logging utilities,
environment wrappers, extended support (e.g. different action spaces)
and learning algorithms.
However sometimes these utilities were too niche to be considered for
stable-baselines or proved to be too difficult to integrate well into
the existing code without creating a mess. sb3-contrib aims to fix this by not
requiring the neatest code integration with existing code and not
setting limits on what is too niche: almost everything remotely useful
goes!
We hope this allows us to provide reliable implementations
following stable-baselines usual standards (consistent style, documentation, etc)
beyond the relatively small scope of utilities in the main repository.
Features
--------
See documentation for the full list of included features.
**RL Algorithms**:
- `Truncated Quantile Critics (TQC)`_
- `Quantile Regression DQN (QR-DQN)`_
**Gym Wrappers**:
- `Time Feature Wrapper`_
Documentation
-------------
Documentation is available online: https://sb3-contrib.readthedocs.io/
Installation
------------
To install Stable-Baselines3 contrib with pip, execute:
::
pip install sb3-contrib
We recommend to use the ``master`` version of Stable Baselines3 and SB3-Contrib.
To install Stable Baselines3 ``master`` version:
::
pip install git+https://github.com/DLR-RM/stable-baselines3
To install Stable Baselines3 contrib ``master`` version:
::
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
Example
-------
SB3-Contrib follows the SB3 API and folder structure. So, if you are familiar with SB3,
using SB3-Contrib should be easy too.
Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
.. code-block:: python
from sb3_contrib import QRDQN
policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")
.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
.. _Truncated Quantile Critics (TQC): https://arxiv.org/abs/2005.04269
.. _Quantile Regression DQN (QR-DQN): https://arxiv.org/abs/1710.10044
.. _Time Feature Wrapper: https://arxiv.org/abs/1712.00378

View file

@ -3,10 +3,10 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to Stable Baselines3 docs! - RL Baselines Made Easy
===========================================================
Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations
========================================================================
`Stable Baselines3 <https://github.com/DLR-RM/stable-baselines3>`_ is a set of improved implementations of reinforcement learning algorithms in PyTorch.
`Stable Baselines3 (SB3) <https://github.com/DLR-RM/stable-baselines3>`_ is a set of reliable implementations of reinforcement learning algorithms in PyTorch.
It is the next major version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_.
@ -16,6 +16,8 @@ RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/
RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning.
SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
Main Features
--------------
@ -45,6 +47,7 @@ Main Features
guide/callbacks
guide/tensorboard
guide/rl_zoo
guide/sb3_contrib
guide/imitation
guide/migration
guide/checking_nan

View file

@ -3,7 +3,7 @@
Changelog
==========
Pre-Release 0.11.0a2 (WIP)
Pre-Release 0.11.0a4 (WIP)
-------------------------------
Breaking Changes:
@ -44,6 +44,9 @@ Others:
- Add signatures to callable type annotations (@ernestum)
- Improve error message in ``NatureCNN``
- Added checks for supported action spaces to improve clarity of error messages for the user
- Renamed variables in the ``train()`` method of ``SAC``, ``TD3`` and ``DQN`` to match SB3-Contrib.
- Updated docker base image to Ubuntu 18.04
- Set tensorboard min version to 2.2.0 (earlier version are apparently not working with PyTorch)
Documentation:
^^^^^^^^^^^^^^
@ -55,6 +58,7 @@ Documentation:
- Added example of learning rate schedule
- Added SUMO-RL as example project (@LucasAlegre)
- Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre)
- Added SB3-Contrib page
Pre-Release 0.10.0 (2020-10-28)
-------------------------------

View file

@ -1,7 +1,7 @@
#!/bin/bash
CPU_PARENT=ubuntu:16.04
GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu16.04
CPU_PARENT=ubuntu:18.04
GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04
TAG=stablebaselines/stable-baselines3
VERSION=$(cat ./stable_baselines3/version.txt)

View file

@ -10,7 +10,7 @@ long_description = """
# Stable Baselines3
Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.
@ -29,6 +29,9 @@ https://stable-baselines3.readthedocs.io/en/master/
RL Baselines3 Zoo:
https://github.com/DLR-RM/rl-baselines3-zoo
SB3 Contrib:
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
## Quick example
Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym.
@ -112,7 +115,7 @@ setup(
"atari_py~=0.2.0",
"pillow",
# Tensorboard support
"tensorboard",
"tensorboard>=2.2.0",
# Checking memory taken by replay buffer
"psutil",
],

View file

@ -155,23 +155,23 @@ class DQN(OffPolicyAlgorithm):
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
with th.no_grad():
# Compute the target Q values
target_q = self.q_net_target(replay_data.next_observations)
# Compute the next Q-values using the target network
next_q_values = self.q_net_target(replay_data.next_observations)
# Follow greedy policy: use the one with the highest value
target_q, _ = target_q.max(dim=1)
next_q_values, _ = next_q_values.max(dim=1)
# Avoid potential broadcast issue
target_q = target_q.reshape(-1, 1)
next_q_values = next_q_values.reshape(-1, 1)
# 1-step TD target
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q estimates
current_q = self.q_net(replay_data.observations)
# Get current Q-values estimates
current_q_values = self.q_net(replay_data.observations)
# Retrieve the q-values for the actions from the replay buffer
current_q = th.gather(current_q, dim=1, index=replay_data.actions.long())
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q, target_q)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
losses.append(loss.item())
# Optimize the policy

View file

@ -74,7 +74,6 @@ class QNetwork(BasePolicy):
features_dim=self.features_dim,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor,
epsilon=self.epsilon,
)
)
return data

View file

@ -223,20 +223,20 @@ class SAC(OffPolicyAlgorithm):
with th.no_grad():
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute the target Q value: min over all critics targets
targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
target_q, _ = th.min(targets, dim=1, keepdim=True)
# Compute the next Q values: min over all critics targets
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
# add entropy term
target_q = target_q - ent_coef * next_log_prob.reshape(-1, 1)
next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q estimates for each critic network
# Get current Q-values estimates for each critic network
# using action from the replay buffer
current_q_estimates = self.critic(replay_data.observations, replay_data.actions)
current_q_values = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = 0.5 * sum([F.mse_loss(current_q, q_backup) for current_q in current_q_estimates])
critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
critic_losses.append(critic_loss.item())
# Optimize the critic

View file

@ -146,16 +146,16 @@ class TD3(OffPolicyAlgorithm):
noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
# Compute the target Q value: min over all critics targets
targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
target_q, _ = th.min(targets, dim=1, keepdim=True)
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
# Compute the next Q-values: min over all critics targets
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q estimates for each critic network
current_q_estimates = self.critic(replay_data.observations, replay_data.actions)
# Get current Q-values estimates for each critic network
current_q_values = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = sum([F.mse_loss(current_q, target_q) for current_q in current_q_estimates])
critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
critic_losses.append(critic_loss.item())
# Optimize the critics

View file

@ -1 +1 @@
0.11.0a2
0.11.0a4

View file

@ -408,6 +408,82 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
os.remove(tmp_path / "actor.pkl")
@pytest.mark.parametrize("model_class", [DQN])
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_q_net(tmp_path, model_class, policy_str):
"""
Test saving and loading q-network/quantile net only.
:param model_class: (BaseAlgorithm) A RL model
:param policy_str: (str) Name of the policy.
"""
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
if model_class in [DQN]:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(
buffer_size=250,
learning_starts=100,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
env = DummyVecEnv([lambda: env])
# create model
model = model_class(policy_str, env, verbose=1, **kwargs)
model.learn(total_timesteps=300)
env.reset()
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
q_net = model.q_net
q_net_class = q_net.__class__
# Get dictionary of current parameters
params = deepcopy(q_net.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Update model parameters with the new random values
q_net.load_state_dict(random_params)
new_params = q_net.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
params = new_params
# get selected actions
selected_actions, _ = q_net.predict(observations, deterministic=True)
# Save and load q_net
q_net.save(tmp_path / "q_net.pkl")
del q_net
q_net = q_net_class.load(tmp_path / "q_net.pkl")
# check if params are still the same after load
new_params = q_net.state_dict()
# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = q_net.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# clear file from os
os.remove(tmp_path / "q_net.pkl")
@pytest.mark.parametrize("pathtype", [str, pathlib.Path])
def test_open_file_str_pathlib(tmp_path, pathtype):
# check that suffix isn't added because we used open_path first