.. _td3:

.. automodule:: stable_baselines3.td3


TD3
===

`Twin Delayed DDPG (TD3) <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ was introduced in the paper
*Addressing Function Approximation Error in Actor-Critic Methods*.
TD3 is a direct successor of :ref:`DDPG <ddpg>` and improves on it using three major tricks:
clipped double Q-learning, delayed policy updates and target policy smoothing.
We recommend reading the `OpenAI Spinning Up guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those tricks.
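
To make those tricks concrete, here is a minimal, illustrative sketch of how a TD3 critic target could be computed.
This is not SB3's internal code: ``actor_target`` and ``critic_target`` are hypothetical stand-ins for the target networks.

.. code-block:: python

    import torch as th

    def td3_critic_target(actor_target, critic_target, next_obs, rewards, dones,
                          gamma=0.99, target_noise=0.2, noise_clip=0.5):
        """Illustrative TD3 target computation, not SB3's implementation."""
        # Target policy smoothing: perturb the target action with
        # clipped Gaussian noise to smooth the value estimate
        next_actions = actor_target(next_obs)
        noise = (th.randn_like(next_actions) * target_noise).clamp(-noise_clip, noise_clip)
        next_actions = (next_actions + noise).clamp(-1.0, 1.0)
        # Clipped double Q-learning: bootstrap from the minimum of the
        # two target critics to limit overestimation
        q1, q2 = critic_target(next_obs, next_actions)
        next_q = th.min(q1, q2)
        # The third trick, delayed policy updates, lives in the training loop:
        # the actor and target networks are updated only every few critic updates
        return rewards + (1.0 - dones) * gamma * next_q
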
.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy

Notes
-----

- Original paper: https://arxiv.org/pdf/1802.09477.pdf
- OpenAI Spinning Up guide on TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
- Original implementation: https://github.com/sfujim/TD3

.. note::

    The default policy for TD3 differs a bit from the ``MlpPolicy`` of other algorithms: it uses ReLU
    instead of tanh activation, to match the original paper.
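
If you prefer the tanh activation used by the default policies of the other algorithms, it can be restored
through ``policy_kwargs`` (a minimal sketch, using the same Pendulum env as in the example below):

.. code-block:: python

    import torch as th

    from stable_baselines3 import TD3

    # Override the default ReLU activation with tanh via policy_kwargs
    model = TD3("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(activation_fn=th.nn.Tanh))
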
Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:

============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
============= ====== ===========
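
As the table shows, TD3 only acts in continuous (``Box``) action spaces; a quick sanity check
before training might look like this (a minimal sketch):

.. code-block:: python

    import gym

    env = gym.make("Pendulum-v0")
    # TD3 requires a continuous action space (Box)
    assert isinstance(env.action_space, gym.spaces.Box), "TD3 only supports Box action spaces"
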
Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from stable_baselines3 import TD3
    from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

    env = gym.make("Pendulum-v0")

    # The noise objects for TD3
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=10000, log_interval=10)
    model.save("td3_pendulum")
    env = model.get_env()

    del model  # remove to demonstrate saving and loading

    model = TD3.load("td3_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
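
The example above also imports ``OrnsteinUhlenbeckActionNoise``, which is a drop-in replacement for
``NormalActionNoise`` that produces temporally correlated noise (a minimal sketch, reusing Pendulum's
single action dimension):

.. code-block:: python

    import numpy as np

    from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise

    n_actions = 1  # Pendulum-v0 has a single continuous action
    # Temporally correlated noise, originally used with DDPG
    action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
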
Results
-------

PyBullet Environments
^^^^^^^^^^^^^^^^^^^^^

Results on the PyBullet benchmark (1M steps) using 3 seeds.
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.

.. note::

    Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).

*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.

+--------------+--------------+--------------+--------------+
| Environments | SAC | SAC | TD3 |
+==============+==============+==============+==============+
| | Gaussian | gSDE | Gaussian |
+--------------+--------------+--------------+--------------+
| HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35 |
+--------------+--------------+--------------+--------------+
| Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43 |
+--------------+--------------+--------------+--------------+
| Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126 |
+--------------+--------------+--------------+--------------+
| Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185 |
+--------------+--------------+--------------+--------------+
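
For reference, the *gSDE* column above corresponds to setting ``use_sde=True``, which ``SAC`` supports
(a minimal sketch, assuming ``pybullet`` is installed so the Bullet envs are registered):

.. code-block:: python

    import pybullet_envs  # noqa: F401 (registers the Bullet envs with gym)

    from stable_baselines3 import SAC

    # gSDE exploration instead of unstructured Gaussian noise
    model = SAC("MlpPolicy", "HalfCheetahBulletEnv-v0", use_sde=True, verbose=1)
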
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo/

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

    python train.py --algo td3 --env $ENV_ID --eval-episodes 10 --eval-freq 10000

Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a td3 -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/td3_results
    python scripts/plot_from_file.py -i logs/td3_results.pkl -latex -l TD3

Parameters
----------

.. autoclass:: TD3
    :members:
    :inherited-members:


.. _td3_policies:

TD3 Policies
-------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:

.. autoclass:: stable_baselines3.td3.policies.TD3Policy
    :members:
    :noindex:

.. autoclass:: CnnPolicy
    :members: