mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-29 23:07:07 +00:00
Update documentation (#199)
* Update doc and add new example * Add save/load replay buffer example * Add save format + export doc * Add example for get/set parameters * Typos and minor edits * Add results sections * Add note about performance * Add DDPG results * Address comments * Fix grammar/wording Co-authored-by: Anssi "Miffyli" Kanervisto <kaneran21@hotmail.com>
This commit is contained in:
parent
6327cc6156
commit
897e98c4e2
17 changed files with 715 additions and 18 deletions
|
|
@ -19,6 +19,9 @@ These algorithms will make it easier for the research community and industry to
|
|||
|
||||
## Main Features
|
||||
|
||||
**The performance of each algorithm was tested** (see the *Results* section on their respective pages);
see issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
|
||||
|
||||
|
||||
| **Features** | **Stable-Baselines3** |
|
||||
| --------------------------- | ----------------------|
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ To build a custom callback, you need to create a class that derives from ``BaseC
|
|||
This will give you access to events (``_on_training_start``, ``_on_step``) and useful variables (like `self.model` for the RL model).
|
||||
|
||||
|
||||
.. You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see :ref:`Examples <examples>`), and one for logging additional values with Tensorboard (see :ref:`Tensorboard section <tensorboard>`).
|
||||
You can find two examples of custom callbacks in the documentation: one for saving the best model according to the training reward (see :ref:`Examples <examples>`), and one for logging additional values with Tensorboard (see :ref:`Tensorboard section <tensorboard>`).
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
|
|
|||
|
|
@ -7,6 +7,13 @@ Stable Baselines3 provides policy networks for images (CnnPolicies)
|
|||
and other types of input features (MlpPolicies).
|
||||
|
||||
|
||||
.. warning::
|
||||
For A2C and PPO, continuous actions are clipped during training and testing
|
||||
(to avoid out-of-bound errors). SAC, DDPG and TD3 squash the action, using a ``tanh()`` transformation,
|
||||
which handles bounds more correctly.
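The difference between the two bound-handling strategies can be illustrated with a minimal sketch in plain Python. The function names and the bounds used here are illustrative assumptions, not part of the SB3 API.

```python
import math


def clip_action(action: float, low: float, high: float) -> float:
    """Clip an unbounded action to [low, high] (clipping-style handling)."""
    return max(low, min(high, action))


def squash_action(raw: float, low: float, high: float) -> float:
    """Squash an unbounded value with tanh, then rescale from [-1, 1] to [low, high]."""
    squashed = math.tanh(raw)
    return low + 0.5 * (squashed + 1.0) * (high - low)


# Compare both strategies on values inside and outside the bounds
for raw in (-3.0, 0.5, 3.0):
    print(raw, clip_action(raw, -2.0, 2.0), round(squash_action(raw, -2.0, 2.0), 3))
```

Clipping maps every out-of-range value to the same boundary, whereas the ``tanh()`` squashing is smooth and keeps the moderate values used here strictly inside the bounds.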
|
||||
|
||||
|
||||
|
||||
Custom Policy Architecture
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ notebooks:
|
|||
- `RL Baselines zoo`_
|
||||
- `PyBullet`_
|
||||
- `Hindsight Experience Replay`_
|
||||
- `Advanced Saving and Loading`_
|
||||
|
||||
.. _Getting Started: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb
|
||||
.. _Training, Saving, Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
|
||||
|
|
@ -28,6 +29,7 @@ notebooks:
|
|||
.. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
|
||||
.. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
|
||||
.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
|
||||
.. _Advanced Saving and Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
|
||||
|
||||
.. |colab| image:: ../_static/img/colab.svg
|
||||
|
||||
|
|
@ -417,6 +419,171 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
|
|||
obs = env.reset()
|
||||
|
||||
|
||||
Advanced Saving and Loading
|
||||
---------------------------------
|
||||
|
||||
In this example, we show how to use some advanced features of Stable-Baselines3 (SB3):
|
||||
how to easily create a test environment to evaluate an agent periodically,
|
||||
use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
|
||||
|
||||
By default, the replay buffer is not saved when calling ``model.save()``, in order to save space on the disk (a replay buffer can be up to several GB when using images).
|
||||
However, SB3 provides the ``save_replay_buffer()`` and ``load_replay_buffer()`` methods to save and load it separately.
|
||||
|
||||
|
||||
.. image:: ../_static/img/colab-badge.svg
|
||||
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/advanced_saving_loading.ipynb
|
||||
|
||||
Stable-Baselines3 can automatically create an environment for evaluation.
For that, you only need to specify ``create_eval_env=True`` when passing the Gym ID of the environment while creating the agent.
Behind the scenes, SB3 uses an :ref:`EvalCallback <callbacks>`.
|
||||
|
||||
.. code-block:: python

  from stable_baselines3 import SAC
  from stable_baselines3.common.evaluation import evaluate_policy
  from stable_baselines3.sac.policies import MlpPolicy

  # Create the model, the training environment
  # and the test environment (for evaluation)
  model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1,
              learning_rate=1e-3, create_eval_env=True)

  # Evaluate the model every 1000 steps on 5 test episodes
  # and save the evaluation to the "logs/" folder
  model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/")

  # save the model
  model.save("sac_pendulum")

  # the saved model does not contain the replay buffer
  loaded_model = SAC.load("sac_pendulum")
  print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")

  # now save the replay buffer too
  model.save_replay_buffer("sac_replay_buffer")

  # load it into the loaded_model
  loaded_model.load_replay_buffer("sac_replay_buffer")

  # now the loaded replay is not empty anymore
  print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")

  # Save the policy independently from the model
  # Note: if you don't save the complete model with `model.save()`
  # you cannot continue training afterward
  policy = model.policy
  policy.save("sac_policy_pendulum.pkl")

  # Retrieve the environment
  env = model.get_env()

  # Evaluate the policy
  mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True)

  print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

  # Load the policy independently from the model
  # (use the same path that was passed to `policy.save()`)
  saved_policy = MlpPolicy.load("sac_policy_pendulum.pkl")

  # Evaluate the loaded policy
  mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True)

  print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
|
||||
|
||||
|
||||
|
||||
Accessing and modifying model parameters
|
||||
----------------------------------------
|
||||
|
||||
You can access a model's parameters via the ``load_parameters`` and ``get_parameters`` functions,
or via ``model.policy.state_dict()`` (and ``load_state_dict()``),
which use dictionaries that map variable names to PyTorch tensors.

These functions are useful when you need to, e.g., evaluate a large set of models with the same network structure,
visualize different layers of the network or modify parameters manually.

Policies also offer a simple way to save/load weights as a NumPy vector, using the ``parameters_to_vector()``
and ``load_from_vector()`` methods.

The following example demonstrates reading parameters, modifying some of them and loading them into the model
by implementing an `evolution strategy (ES) <http://blog.otoro.net/2017/10/29/visual-evolution-strategies/>`_
for solving the ``CartPole-v1`` environment. The initial guess for the parameters is obtained by running
A2C policy gradient updates on the model.
|
||||
|
||||
.. code-block:: python

  from typing import Dict

  import gym
  import numpy as np
  import torch as th

  from stable_baselines3 import A2C
  from stable_baselines3.common.evaluation import evaluate_policy


  def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
      """Mutate parameters by adding normal noise to them"""
      return dict((name, param + th.randn_like(param)) for name, param in params.items())


  # Create policy with a small network
  model = A2C(
      "MlpPolicy",
      "CartPole-v1",
      ent_coef=0.0,
      policy_kwargs={"net_arch": [32]},
      seed=0,
      learning_rate=0.05,
  )

  # Use traditional actor-critic policy gradient updates to
  # find good initial parameters
  model.learn(total_timesteps=10000)

  # Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
  # in their name: only these ones affect the action.
  # NOTE: you can retrieve those parameters using model.get_parameters() too
  mean_params = dict(
      (key, value)
      for key, value in model.policy.state_dict().items()
      if ("policy" in key or "shared_net" in key or "action" in key)
  )

  # population size of 50 individuals
  pop_size = 50
  # Keep top 10%
  n_elite = pop_size // 10
  # Retrieve the environment
  env = model.get_env()

  for iteration in range(10):
      # Create population of candidates and evaluate them
      population = []
      for population_i in range(pop_size):
          candidate = mutate(mean_params)
          # Load new policy parameters to agent.
          # Tell function that it should only update parameters
          # we give it (policy parameters)
          model.policy.load_state_dict(candidate, strict=False)
          # Evaluate the candidate
          fitness, _ = evaluate_policy(model, env)
          population.append((candidate, fitness))
      # Take top 10% and use average over their parameters as next mean parameter
      top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite]
      mean_params = dict(
          (
              name,
              th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0),
          )
          for name in mean_params.keys()
      )
      mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite
      print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}")
      print(f"Best fitness: {top_candidates[0][1]:.2f}")
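
The sample, evaluate and average-the-elite loop used in this example does not depend on SB3 itself. The following is a minimal sketch of the same idea on a toy two-parameter objective, using only the standard library. The objective, the helper names and the noise scale are assumptions chosen for illustration.

```python
import random
from typing import List, Tuple


def fitness(params: List[float]) -> float:
    """Toy objective: higher is better, maximum at (1, -2)."""
    return -((params[0] - 1.0) ** 2 + (params[1] + 2.0) ** 2)


def es_step(mean: List[float], rng: random.Random, pop_size: int = 50, n_elite: int = 5) -> List[float]:
    """Sample candidates around the mean, keep the elite, return their average."""
    population: List[Tuple[List[float], float]] = []
    for _ in range(pop_size):
        # Mutate the current mean with unit Gaussian noise
        candidate = [m + rng.gauss(0.0, 1.0) for m in mean]
        population.append((candidate, fitness(candidate)))
    # Keep the best candidates and average their parameters
    elite = sorted(population, key=lambda item: item[1], reverse=True)[:n_elite]
    return [sum(c[i] for c, _ in elite) / n_elite for i in range(len(mean))]


rng = random.Random(0)
start = [5.0, 5.0]
mean = list(start)
for _ in range(10):
    mean = es_step(mean, rng)
print(f"fitness: {fitness(start):.2f} -> {fitness(mean):.2f}")
```

Replacing ``fitness`` with an episode-return evaluation and the lists with parameter tensors gives the structure of the SB3 example.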
|
||||
|
||||
|
||||
|
||||
Record a Video
|
||||
--------------
|
||||
|
||||
|
|
|
|||
67
docs/guide/export.rst
Normal file
67
docs/guide/export.rst
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
.. _export:
|
||||
|
||||
|
||||
Exporting models
|
||||
================
|
||||
|
||||
After training an agent, you may want to deploy/use it in another language
|
||||
or framework, like `tensorflowjs <https://github.com/tensorflow/tfjs>`_.
|
||||
Stable Baselines3 does not include tools to export models to other frameworks, but
|
||||
this document aims to cover parts that are required for exporting along with
|
||||
more detailed stories from users of Stable Baselines3.
|
||||
|
||||
|
||||
Background
|
||||
----------
|
||||
|
||||
In Stable Baselines3, the controller is stored inside policies which convert
|
||||
observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC)
|
||||
contains a policy object which represents the currently learned behavior,
|
||||
accessible via ``model.policy``.
|
||||
|
||||
Policies hold enough information to do the inference (i.e. predict actions),
|
||||
so it is enough to export these policies (cf :ref:`examples <examples>`)
|
||||
to do inference in another framework.
|
||||
|
||||
.. warning::
|
||||
When using CNN policies, the observation is normalized during pre-processing.
This pre-processing is done *inside* the policy (dividing by 255 to have values in [0, 1]).
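When the network is re-implemented in another framework, this scaling step presumably has to be reproduced by hand, since the exported weights were trained on inputs in [0, 1]. A minimal sketch follows, with a hypothetical helper name and a flat list standing in for an image:

```python
from typing import List


def normalize_image(pixels: List[int]) -> List[float]:
    """Scale uint8 pixel values (0-255) to floats in [0, 1]."""
    if any(p < 0 or p > 255 for p in pixels):
        raise ValueError("expected pixel values in [0, 255]")
    return [p / 255.0 for p in pixels]


raw_obs = [0, 51, 255]                # raw observation given to the policy
net_input = normalize_image(raw_obs)  # values seen by the network layers
print(net_input)
```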
|
||||
|
||||
|
||||
Export to ONNX
|
||||
-----------------
|
||||
|
||||
TODO: help is welcome!
|
||||
|
||||
|
||||
Export to C++
|
||||
-----------------
|
||||
|
||||
(using PyTorch JIT)
|
||||
TODO: help is welcome!
|
||||
|
||||
|
||||
Export to tensorflowjs / ONNX-JS
|
||||
--------------------------------
|
||||
|
||||
TODO: contributors' help is welcome!
|
||||
Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
|
||||
|
||||
|
||||
|
||||
Manual export
|
||||
-------------
|
||||
|
||||
You can also manually export required parameters (weights) and construct the
|
||||
network in your desired framework.
|
||||
|
||||
You can access parameters of the model via agents'
|
||||
:func:`get_parameters <stable_baselines3.common.base_class.BaseAlgorithm.get_parameters>` function.
|
||||
As policies are also PyTorch modules, you can also access ``model.policy.state_dict()`` directly.
|
||||
To find the architecture of the networks for each algorithm, it is best to check the ``policies.py`` file located
|
||||
in their respective folders.
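Once the weights are exported, rebuilding the network comes down to replaying the same layers. The sketch below assumes a single hidden layer with a tanh activation and hypothetical parameter names. The real names and architecture must be read from the ``state_dict()`` keys and the relevant ``policies.py`` file. Weight matrices are stored as (out_features, in_features), the layout used by PyTorch linear layers.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def linear(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """Compute weight @ x + bias, with weight stored as (out_features, in_features)."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def mlp_forward(x: List[float], params: Dict[str, list]) -> List[float]:
    """One tanh hidden layer followed by a linear output layer."""
    hidden = [math.tanh(h) for h in linear(x, params["hidden.weight"], params["hidden.bias"])]
    return linear(hidden, params["out.weight"], params["out.bias"])


# Hypothetical exported parameters (e.g. tensors converted to nested lists)
params = {
    "hidden.weight": [[1.0, 0.0], [0.0, 1.0]],
    "hidden.bias": [0.0, 0.0],
    "out.weight": [[1.0, -1.0]],
    "out.bias": [0.5],
}
print(mlp_forward([0.0, 0.0], params))
```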
|
||||
|
||||
.. note::
|
||||
|
||||
In most cases, we recommend using PyTorch methods ``state_dict()`` and ``load_state_dict()`` from the policy,
|
||||
unless you need to access the optimizers' state dict too. In that case, you need to call ``get_parameters()``.
|
||||
|
|
@ -59,6 +59,7 @@ Moved Files
|
|||
- ``bench/monitor.py`` -> ``common/monitor.py``
|
||||
- ``logger.py`` -> ``common/logger.py``
|
||||
- ``results_plotter.py`` -> ``common/results_plotter.py``
|
||||
- ``common/cmd_util.py`` -> ``common/env_util.py``
|
||||
|
||||
Utility functions are no longer exported from the ``common`` module; you should import them with their absolute path, e.g.:
|
||||
|
||||
|
|
|
|||
|
|
@ -146,17 +146,17 @@ for continuous actions problems (cf *Bullet* envs).
|
|||
|
||||
|
||||
|
||||
.. Goal Environment
|
||||
.. -----------------
|
||||
..
|
||||
.. If your environment follows the ``GoalEnv`` interface (cf `HER <../modules/her.html>`_), then you should use
|
||||
.. HER + (SAC/TD3/DDPG/DQN) depending on the action space.
|
||||
..
|
||||
..
|
||||
.. .. note::
|
||||
..
|
||||
.. The number of workers is an important hyperparameters for experiments with HER
|
||||
..
|
||||
Goal Environment
|
||||
-----------------
|
||||
|
||||
If your environment follows the ``GoalEnv`` interface (cf :ref:`HER <her>`), then you should use
|
||||
HER + (SAC/TD3/DDPG/DQN) depending on the action space.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
The number of workers is an important hyperparameter for experiments with HER
|
||||
|
||||
|
||||
|
||||
Tips and Tricks when creating a custom environment
|
||||
|
|
|
|||
59
docs/guide/save_format.rst
Normal file
59
docs/guide/save_format.rst
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
.. _save_format:
|
||||
|
||||
|
||||
On saving and loading
|
||||
=====================
|
||||
|
||||
Stable Baselines3 (SB3) stores both neural network parameters and algorithm-related parameters such as
|
||||
exploration schedule, number of environments and observation/action space. This allows continual learning and easy
|
||||
use of trained agents without training, but it is not without its issues. The following describes the format
|
||||
used to save agents in SB3 along with its pros and shortcomings.
|
||||
|
||||
Terminology used in this page:
|
||||
|
||||
- *parameters* refer to neural network parameters (also called "weights"). This is a dictionary
|
||||
mapping variable name to a PyTorch tensor.
|
||||
- *data* refers to RL algorithm parameters, e.g. learning rate, exploration schedule, action/observation space.
|
||||
These depend on the algorithm used. This is a dictionary mapping class variable names to their values.
|
||||
|
||||
|
||||
Zip-archive
|
||||
-----------
|
||||
|
||||
The saved file is a zip archive containing a JSON dump, PyTorch state dictionaries and PyTorch variables. The data dictionary (class parameters)
|
||||
is stored as a JSON file, model parameters and optimizers are serialized with ``torch.save()`` function and these files
|
||||
are stored under a single .zip archive.
|
||||
|
||||
Any objects that are not JSON serializable are serialized with cloudpickle and stored as a base64-encoded
|
||||
string in the JSON file, along with some information that was stored in the serialization. This allows
|
||||
inspecting stored objects without deserializing the object itself.
|
||||
|
||||
This format allows skipping elements in the file, i.e. we can skip deserializing objects that are
|
||||
broken/non-serializable.
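The idea behind this format can be sketched with the standard library alone. This is not SB3's actual implementation. Here ``pickle`` stands in for cloudpickle, raw bytes stand in for the output of ``torch.save()``, and the ``":type:"`` and ``":serialized:"`` keys are illustrative assumptions.

```python
import base64
import io
import json
import pickle
import zipfile


def to_json_safe(value):
    """Return value unchanged if JSON-serializable, else a base64-pickled record."""
    try:
        json.dumps(value)
        return value
    except TypeError:
        payload = base64.b64encode(pickle.dumps(value)).decode("ascii")
        # Keep readable metadata next to the opaque payload
        return {":type:": str(type(value)), ":serialized:": payload}


def save_archive(data: dict, weights: bytes) -> bytes:
    """Write a JSON 'data' entry and a binary weights entry into one zip archive."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, mode="w") as archive:
        archive.writestr("data", json.dumps({k: to_json_safe(v) for k, v in data.items()}))
        archive.writestr("policy.pth", weights)
    return buffer.getvalue()


# A set is not JSON serializable, so it goes through the pickle fallback
blob = save_archive({"learning_rate": 1e-3, "tags": {"a", "b"}}, b"\x00\x01")
with zipfile.ZipFile(io.BytesIO(blob)) as archive:
    names = archive.namelist()
    data = json.loads(archive.read("data"))

print(names)
print(data["tags"][":type:"])  # inspect the stored type without unpickling
restored = pickle.loads(base64.b64decode(data["tags"][":serialized:"]))
```

Because each non-JSON object is an independent entry, a loader can skip one that fails to deserialize and still read the rest.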
|
||||
|
||||
.. This can be done via ``custom_objects`` argument to load functions.
|
||||
|
||||
|
||||
File structure:
|
||||
|
||||
::

  saved_model.zip/
  ├── data                             JSON file of class-parameters (dictionary)
  ├── *.optimizer.pth                  PyTorch optimizers serialized
  ├── policy.pth                       PyTorch state dictionary of the policy saved
  ├── pytorch_variables.pth            Additional PyTorch variables
  ├── _stable_baselines3_version       contains the SB3 version with which the model was saved
|
||||
|
||||
|
||||
Pros:
|
||||
|
||||
- More robust to unserializable objects (one bad object does not break everything).
|
||||
- Saved files can be inspected/extracted with zip-archive explorers and by other languages.
|
||||
|
||||
|
||||
Cons:
|
||||
|
||||
- More complex implementation.
|
||||
- Still relies partly on cloudpickle for complex objects (e.g. custom functions)
|
||||
which can lead to `incompatibilities <https://github.com/DLR-RM/stable-baselines3/issues/172>`_ between Python versions.
|
||||
|
|
@ -26,6 +26,7 @@ Main Features
|
|||
- Tests, high code coverage and type hints
|
||||
- Clean code
|
||||
- Tensorboard support
|
||||
- **The performance of each algorithm was tested** (see the *Results* section on their respective pages)
|
||||
|
||||
|
||||
.. toctree::
|
||||
|
|
@ -48,6 +49,8 @@ Main Features
|
|||
guide/migration
|
||||
guide/checking_nan
|
||||
guide/developer
|
||||
guide/save_format
|
||||
guide/export
|
||||
|
||||
|
||||
.. toctree::
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ Documentation:
|
|||
- Added first draft of migration guide
|
||||
- Added intro to `imitation <https://github.com/HumanCompatibleAI/imitation>`_ library (@shwang)
|
||||
- Enabled doc for ``CnnPolicies``
|
||||
- Added advanced saving and loading example
|
||||
- Added base doc for exporting models
|
||||
- Added example for getting and setting model parameters
|
||||
|
||||
|
||||
Pre-Release 0.9.0 (2020-10-03)
|
||||
|
|
|
|||
|
|
@ -73,6 +73,72 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
|
|||
obs, rewards, dones, info = env.step(action)
|
||||
env.render()
|
||||
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
Atari Games
|
||||
^^^^^^^^^^^
|
||||
|
||||
The complete learning curves are available in the `associated PR #110 <https://github.com/DLR-RM/stable-baselines3/pull/110>`_.
|
||||
|
||||
|
||||
PyBullet Environments
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Results on the PyBullet benchmark (2M steps) using 6 seeds.
|
||||
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).
|
||||
|
||||
|
||||
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
|
||||
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
|
||||
|
||||
+--------------+--------------+--------------+--------------+-------------+
| Environments | A2C          | A2C          | PPO          | PPO         |
+==============+==============+==============+==============+=============+
|              | Gaussian     | gSDE         | Gaussian     | gSDE        |
+--------------+--------------+--------------+--------------+-------------+
| HalfCheetah  | 2003 +/- 54  | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 |
+--------------+--------------+--------------+--------------+-------------+
| Ant          | 2286 +/- 72  | 2443 +/- 89  | 2364 +/- 120 | 2782 +/- 76 |
+--------------+--------------+--------------+--------------+-------------+
| Hopper       | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 |
+--------------+--------------+--------------+--------------+-------------+
| Walker2D     | 577 +/- 65   | 839 +/- 56   | 1230 +/- 147 | 2019 +/- 64 |
+--------------+--------------+--------------+--------------+-------------+
|
||||
|
||||
|
||||
How to replicate the results?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
cd rl-baselines3-zoo/
|
||||
|
||||
|
||||
Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||
|
||||
|
||||
Plot the results (here for PyBullet envs only):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results
|
||||
python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
|
|
|
|||
|
|
@ -10,12 +10,19 @@ DDPG
|
|||
trick for DQN with the deterministic policy gradient, to obtain an algorithm for continuous actions.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
As ``DDPG`` can be seen as a special case of its successor :ref:`TD3 <td3>`,
|
||||
they share the same policies and same implementation.
|
||||
|
||||
|
||||
.. rubric:: Available Policies
|
||||
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
|
||||
MlpPolicy
|
||||
CnnPolicy
|
||||
|
||||
|
||||
Notes
|
||||
|
|
@ -25,10 +32,6 @@ Notes
|
|||
- DDPG Paper: https://arxiv.org/abs/1509.02971
|
||||
- OpenAI Spinning Up Guide for DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html
|
||||
|
||||
.. note::
|
||||
|
||||
The default policy for DDPG uses a ReLU activation, to match the original paper, whereas most other algorithms' MlpPolicy uses a tanh activation.
|
||||
to match the original paper
|
||||
|
||||
|
||||
Can I use?
|
||||
|
|
@ -81,6 +84,66 @@ Example
|
|||
obs, rewards, dones, info = env.step(action)
|
||||
env.render()
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
PyBullet Environments
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Results on the PyBullet benchmark (1M steps) using 6 seeds.
|
||||
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
Hyperparameters of :ref:`TD3 <td3>` from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used for ``DDPG``.
|
||||
|
||||
|
||||
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
|
||||
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
|
||||
|
||||
+--------------+--------------+--------------+--------------+
| Environments | DDPG         | TD3          | SAC          |
+==============+==============+==============+==============+
|              | Gaussian     | Gaussian     | gSDE         |
+--------------+--------------+--------------+--------------+
| HalfCheetah  | 2272 +/- 69  | 2774 +/- 35  | 2984 +/- 202 |
+--------------+--------------+--------------+--------------+
| Ant          | 1651 +/- 407 | 3305 +/- 43  | 3102 +/- 37  |
+--------------+--------------+--------------+--------------+
| Hopper       | 1201 +/- 211 | 2429 +/- 126 | 2262 +/- 1   |
+--------------+--------------+--------------+--------------+
| Walker2D     | 882 +/- 186  | 2063 +/- 185 | 2136 +/- 67  |
+--------------+--------------+--------------+--------------+
|
||||
|
||||
|
||||
|
||||
How to replicate the results?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
cd rl-baselines3-zoo/
|
||||
|
||||
|
||||
Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo ddpg --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||
|
||||
|
||||
Plot the results:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/all_plots.py -a ddpg -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ddpg_results
|
||||
python scripts/plot_from_file.py -i logs/ddpg_results.pkl -latex -l DDPG
|
||||
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@
|
|||
DQN
|
||||
===
|
||||
|
||||
`Deep Q Network (DQN) <https://arxiv.org/abs/1312.5602>`_
|
||||
`Deep Q Network (DQN) <https://arxiv.org/abs/1312.5602>`_ builds on `Fitted Q-Iteration (FQI) <http://ml.informatik.uni-freiburg.de/former/_media/publications/rieecml05.pdf>`_
|
||||
and makes use of different tricks to stabilize the learning with neural networks: it uses a replay buffer, a target network and gradient clipping.
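The three stabilization tricks can be sketched independently of any deep learning framework. The following is a minimal, standard-library illustration, not SB3's implementation. The class and function names are hypothetical, and lists of floats stand in for network outputs and gradients.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

Transition = Tuple[int, int, float, int, bool]  # (obs, action, reward, next_obs, done)


class ReplayBuffer:
    """Fixed-size FIFO store of transitions, sampled uniformly at random."""

    def __init__(self, capacity: int) -> None:
        self.storage: Deque[Transition] = deque(maxlen=capacity)

    def add(self, transition: Transition) -> None:
        self.storage.append(transition)

    def sample(self, batch_size: int, rng: random.Random) -> List[Transition]:
        return rng.sample(list(self.storage), batch_size)


def td_target(reward: float, done: bool, next_q_values: List[float], gamma: float = 0.99) -> float:
    """Bootstrap from the target network's Q-values for the next observation."""
    return reward if done else reward + gamma * max(next_q_values)


def clip_gradient(grad: List[float], max_norm: float) -> List[float]:
    """Rescale the gradient so that its L2 norm does not exceed max_norm."""
    norm = sum(g * g for g in grad) ** 0.5
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


rng = random.Random(0)
buffer = ReplayBuffer(capacity=2)
for step in range(3):
    buffer.add((step, 0, 1.0, step + 1, False))  # the oldest transition is evicted
batch = buffer.sample(2, rng)
print(batch)
```

Sampling from the buffer decorrelates consecutive transitions. Computing the target from a separate, slowly updated network keeps the regression target from shifting at every gradient step.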
|
||||
|
||||
|
||||
.. rubric:: Available Policies
|
||||
|
||||
|
|
@ -74,6 +76,42 @@ Example
|
|||
if done:
|
||||
obs = env.reset()
|
||||
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
Atari Games
|
||||
^^^^^^^^^^^
|
||||
|
||||
The complete learning curves are available in the `associated PR #110 <https://github.com/DLR-RM/stable-baselines3/pull/110>`_.
|
||||
|
||||
|
||||
How to replicate the results?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
cd rl-baselines3-zoo/
|
||||
|
||||
|
||||
Run the benchmark (replace ``$ENV_ID`` by the env id, for instance ``BreakoutNoFrameskip-v4``):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||
|
||||
|
||||
Plot the results:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/all_plots.py -a dqn -e Pong Breakout -f logs/ -o logs/dqn_results
|
||||
python scripts/plot_from_file.py -i logs/dqn_results.pkl -latex -l DQN
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
|
|
|
|||
|
|
@ -89,6 +89,41 @@ Example
|
|||
obs = env.reset()
|
||||
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
This implementation was tested on the `parking env <https://github.com/eleurent/highway-env>`_
|
||||
using 3 seeds.
|
||||
|
||||
The complete learning curves are available in the `associated PR #120 <https://github.com/DLR-RM/stable-baselines3/pull/120>`_.
|
||||
|
||||
|
||||
|
||||
How to replicate the results?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
cd rl-baselines3-zoo/
|
||||
|
||||
|
||||
Run the benchmark:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo her --env parking-v0 --eval-episodes 10 --eval-freq 10000
|
||||
|
||||
|
||||
Plot the results:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/all_plots.py -a her -e parking-v0 -f logs/ --no-million
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
|
|
|
|||
|
|
@ -74,6 +74,72 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments.
|
|||
obs, rewards, dones, info = env.step(action)
|
||||
env.render()
|
||||
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
Atari Games
|
||||
^^^^^^^^^^^
|
||||
|
||||
The complete learning curves are available in the `associated PR #110 <https://github.com/DLR-RM/stable-baselines3/pull/110>`_.
|
||||
|
||||
|
||||
PyBullet Environments
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Results on the PyBullet benchmark (2M steps) using 6 seeds.
|
||||
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).
|
||||
|
||||
|
||||
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
|
||||
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
|
||||
|
||||
+--------------+--------------+--------------+--------------+-------------+
| Environments | A2C          | A2C          | PPO          | PPO         |
+==============+==============+==============+==============+=============+
|              | Gaussian     | gSDE         | Gaussian     | gSDE        |
+--------------+--------------+--------------+--------------+-------------+
| HalfCheetah  | 2003 +/- 54  | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 |
+--------------+--------------+--------------+--------------+-------------+
| Ant          | 2286 +/- 72  | 2443 +/- 89  | 2364 +/- 120 | 2782 +/- 76 |
+--------------+--------------+--------------+--------------+-------------+
| Hopper       | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 |
+--------------+--------------+--------------+--------------+-------------+
| Walker2D     | 577 +/- 65   | 839 +/- 56   | 1230 +/- 147 | 2019 +/- 64 |
+--------------+--------------+--------------+--------------+-------------+
|
||||
|
||||
|
||||
How to replicate the results?
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
cd rl-baselines3-zoo/
|
||||
|
||||
|
||||
Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||
|
||||
|
||||
Plot the results (here for PyBullet envs only):
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/all_plots.py -a ppo -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/ppo_results
|
||||
python scripts/plot_from_file.py -i logs/ppo_results.pkl -latex -l PPO
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
|
|
|
|||
|
|
@ -88,6 +88,66 @@ Example
|
|||
if done:
|
||||
obs = env.reset()
|
||||
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
PyBullet Environments
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Results on the PyBullet benchmark (1M steps) using 3 seeds.
|
||||
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).
|
||||
|
||||
|
||||
*Gaussian* means that the unstructured Gaussian noise is used for exploration,
|
||||
*gSDE* (generalized State-Dependent Exploration) is used otherwise.
|
||||
|
||||
+--------------+--------------+--------------+--------------+
| Environments | SAC          | SAC          | TD3          |
+==============+==============+==============+==============+
|              | Gaussian     | gSDE         | Gaussian     |
+--------------+--------------+--------------+--------------+
| HalfCheetah  | 2757 +/- 53  | 2984 +/- 202 | 2774 +/- 35  |
+--------------+--------------+--------------+--------------+
| Ant          | 3146 +/- 35  | 3102 +/- 37  | 3305 +/- 43  |
+--------------+--------------+--------------+--------------+
| Hopper       | 2422 +/- 168 | 2262 +/- 1   | 2429 +/- 126 |
+--------------+--------------+--------------+--------------+
| Walker2D     | 2184 +/- 54  | 2136 +/- 67  | 2063 +/- 185 |
+--------------+--------------+--------------+--------------+


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

  git clone https://github.com/DLR-RM/rl-baselines3-zoo
  cd rl-baselines3-zoo/


Run the benchmark (replace ``$ENV_ID`` with one of the environments mentioned above):

.. code-block:: bash

  python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000


Plot the results:

.. code-block:: bash

  python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results
  python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC


Parameters
----------


@ -8,7 +8,7 @@ TD3

`Twin Delayed DDPG (TD3) <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ Addressing Function Approximation Error in Actor-Critic Methods.

-TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing.
+TD3 is a direct successor of :ref:`DDPG <ddpg>` and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing.
We recommend reading the `OpenAI Spinning Up guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those.
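
The three tricks can be illustrated with a small NumPy sketch.
It is not the Stable-Baselines3 implementation: the function and parameter names (``td3_target``, ``should_update_actor``, ``target_noise``, ``noise_clip``, ``policy_delay``)
and the toy target networks are assumptions chosen for the example, with default values commonly used for TD3.

.. code-block:: python

  import numpy as np

  rng = np.random.default_rng(0)


  def td3_target(reward, done, next_obs, target_actor, target_q1, target_q2,
                 gamma=0.99, target_noise=0.2, noise_clip=0.5, max_action=1.0):
      """Compute the critic target for a batch of transitions."""
      next_action = target_actor(next_obs)
      # Target policy smoothing: perturb the target action with clipped Gaussian noise.
      noise = np.clip(rng.normal(0.0, target_noise, size=next_action.shape),
                      -noise_clip, noise_clip)
      next_action = np.clip(next_action + noise, -max_action, max_action)
      # Clipped double Q-learning: keep the smaller of the two target critic values.
      next_q = np.minimum(target_q1(next_obs, next_action),
                          target_q2(next_obs, next_action))
      return reward + gamma * (1.0 - done) * next_q


  def should_update_actor(critic_update_count, policy_delay=2):
      """Delayed policy update: update the actor every `policy_delay` critic updates."""
      return critic_update_count % policy_delay == 0


  # Toy stand-ins (hypothetical) for the target networks.
  def target_actor(obs):
      return np.tanh(obs.sum(axis=1, keepdims=True))


  def target_q1(obs, act):
      return obs.sum(axis=1) + act[:, 0]


  def target_q2(obs, act):
      return obs.sum(axis=1) - act[:, 0]


  next_obs = np.array([[0.1, 0.2], [0.3, -0.4]])
  reward = np.array([1.0, 0.5])
  done = np.array([0.0, 1.0])
  target = td3_target(reward, done, next_obs, target_actor, target_q1, target_q2)

For the terminal transition (``done = 1``) the target reduces to the reward, and taking the minimum of the two critics keeps the target from exceeding either estimate.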


@ -18,6 +18,7 @@ We recommend reading `OpenAI Spinning guide on TD3 <https://spinningup.openai.co
    :nosignatures:

    MlpPolicy
    CnnPolicy


Notes
@ -84,6 +85,64 @@ Example
    obs, rewards, dones, info = env.step(action)
    env.render()

Results
-------

PyBullet Environments
^^^^^^^^^^^^^^^^^^^^^

Results on the PyBullet benchmark (1M steps) using 3 seeds.
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.


.. note::

  Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).


*Gaussian* means that unstructured Gaussian noise is used for exploration;
otherwise, *gSDE* (generalized State-Dependent Exploration) is used.

+--------------+--------------+--------------+--------------+
| Environments | SAC          | SAC          | TD3          |
+==============+==============+==============+==============+
|              | Gaussian     | gSDE         | Gaussian     |
+--------------+--------------+--------------+--------------+
| HalfCheetah  | 2757 +/- 53  | 2984 +/- 202 | 2774 +/- 35  |
+--------------+--------------+--------------+--------------+
| Ant          | 3146 +/- 35  | 3102 +/- 37  | 3305 +/- 43  |
+--------------+--------------+--------------+--------------+
| Hopper       | 2422 +/- 168 | 2262 +/- 1   | 2429 +/- 126 |
+--------------+--------------+--------------+--------------+
| Walker2D     | 2184 +/- 54  | 2136 +/- 67  | 2063 +/- 185 |
+--------------+--------------+--------------+--------------+


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

  git clone https://github.com/DLR-RM/rl-baselines3-zoo
  cd rl-baselines3-zoo/


Run the benchmark (replace ``$ENV_ID`` with one of the environments mentioned above):

.. code-block:: bash

  python train.py --algo td3 --env $ENV_ID --eval-episodes 10 --eval-freq 10000


Plot the results:

.. code-block:: bash

  python scripts/all_plots.py -a td3 -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/td3_results
  python scripts/plot_from_file.py -i logs/td3_results.pkl -latex -l TD3


Parameters
----------
