.. _a2c:
.. automodule:: stable_baselines3.a2c
A2C
====
A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3C) <https://arxiv.org/abs/1602.01783>`_.
It uses multiple workers to avoid the use of a replay buffer.
.. warning::
  If you find training unstable or want to match the performance of stable-baselines A2C, consider using the
  ``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
  You can change the optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))``.
  Read more `here <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.
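
The optimizer swap described in the warning can be sketched as follows. This is a minimal sketch, not a tuned configuration: it assumes a stable-baselines3 version whose policies accept ``optimizer_kwargs``, and it uses ``CartPole-v1`` only as an illustrative environment.

.. code-block:: python

  from stable_baselines3 import A2C
  from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

  # Assumption: eps is forwarded to the optimizer through ``optimizer_kwargs``
  policy_kwargs = dict(
      optimizer_class=RMSpropTFLike,
      optimizer_kwargs=dict(eps=1e-5),
  )

  # The environment id is illustrative; any supported env works the same way
  model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=0)

  # The policy builds its optimizer at construction time, so it can be inspected directly
  print(type(model.policy.optimizer).__name__)
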
Notes
-----
- Original paper: https://arxiv.org/abs/1602.01783
- OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/
Can I use?
----------
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ✔️      ✔️
MultiDiscrete ✔️      ✔️
MultiBinary   ✔️      ✔️
============= ====== ===========

Example
-------
Train an A2C agent on ``CartPole-v1`` using 4 environments.

.. code-block:: python

  from stable_baselines3 import A2C
  from stable_baselines3.common.env_util import make_vec_env

  # Parallel environments
  env = make_vec_env("CartPole-v1", n_envs=4)

  model = A2C("MlpPolicy", env, verbose=1)
  model.learn(total_timesteps=25000)
  model.save("a2c_cartpole")

  del model  # remove to demonstrate saving and loading

  model = A2C.load("a2c_cartpole")

  # Run the loaded agent; the vectorized env resets finished episodes automatically
  obs = env.reset()
  while True:
      action, _states = model.predict(obs)
      obs, rewards, dones, info = env.step(action)
      env.render()

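
The rendering loop above runs until it is interrupted. For a bounded check of a trained agent, one option is the ``evaluate_policy`` helper from ``stable_baselines3.common.evaluation``. The sketch below is illustrative only: the short training budget, the separate single-environment ``eval_env``, and the number of evaluation episodes are assumptions chosen to keep the example quick, not recommended settings.

.. code-block:: python

  from stable_baselines3 import A2C
  from stable_baselines3.common.env_util import make_vec_env
  from stable_baselines3.common.evaluation import evaluate_policy

  # Train briefly on 4 parallel environments (budget kept small for illustration)
  env = make_vec_env("CartPole-v1", n_envs=4)
  model = A2C("MlpPolicy", env, verbose=0)
  model.learn(total_timesteps=2000)

  # Evaluate on a separate single environment for a fixed number of episodes
  eval_env = make_vec_env("CartPole-v1", n_envs=1)
  mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
  print(f"mean_reward={mean_reward:.1f} +/- {std_reward:.1f}")
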
Results
-------
Atari Games
^^^^^^^^^^^
The complete learning curves are available in the `associated PR #110 <https://github.com/DLR-RM/stable-baselines3/pull/110>`_.
PyBullet Environments
^^^^^^^^^^^^^^^^^^^^^
Results on the PyBullet benchmark (2M steps) using 6 seeds.
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.

.. note::

  Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).


*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.

+--------------+--------------+--------------+--------------+-------------+
| Environments | A2C          | A2C          | PPO          | PPO         |
+==============+==============+==============+==============+=============+
|              | Gaussian     | gSDE         | Gaussian     | gSDE        |
+--------------+--------------+--------------+--------------+-------------+
| HalfCheetah  | 2003 +/- 54  | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 |
+--------------+--------------+--------------+--------------+-------------+
| Ant          | 2286 +/- 72  | 2443 +/- 89  | 2364 +/- 120 | 2782 +/- 76 |
+--------------+--------------+--------------+--------------+-------------+
| Hopper       | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 |
+--------------+--------------+--------------+--------------+-------------+
| Walker2D     | 577 +/- 65   | 839 +/- 56   | 1230 +/- 147 | 2019 +/- 64 |
+--------------+--------------+--------------+--------------+-------------+

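
The two exploration modes compared in the table correspond to the ``use_sde`` argument of ``A2C``. The sketch below only shows how the switch is set; it does not reproduce the benchmark. ``Pendulum-v1`` is an assumption used as a lightweight continuous-control stand-in (gSDE requires continuous actions), and default hyperparameters are used rather than the tuned ones from the gSDE paper.

.. code-block:: python

  from stable_baselines3 import A2C

  # Unstructured Gaussian noise (the default exploration for continuous actions)
  model_gaussian = A2C("MlpPolicy", "Pendulum-v1", use_sde=False, verbose=0)

  # Generalized State-Dependent Exploration (gSDE)
  model_gsde = A2C("MlpPolicy", "Pendulum-v1", use_sde=True, verbose=0)

  print(model_gaussian.use_sde, model_gsde.use_sde)
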
How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

  git clone https://github.com/DLR-RM/rl-baselines3-zoo
  cd rl-baselines3-zoo/


Run the benchmark (replace ``$ENV_ID`` with one of the environments listed above):

.. code-block:: bash

  python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000


Plot the results (here for PyBullet envs only):

.. code-block:: bash

  python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results
  python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C

Parameters
----------

.. autoclass:: A2C
  :members:
  :inherited-members:

A2C Policies
-------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
  :members:
  :noindex:

.. autoclass:: CnnPolicy
  :members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
  :members:
  :noindex: