.. _a2c:

.. automodule:: stable_baselines3.a2c


A2C
====

A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3C) <https://arxiv.org/abs/1602.01783>`_.
It uses multiple workers to avoid the use of a replay buffer.


.. warning::

  If you find training unstable or want to match the performance of stable-baselines A2C, consider using the
  ``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
  You can change the optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike, optimizer_kwargs=dict(eps=1e-5)))``.
  Read more `here <https://github.com/DLR-RM/stable-baselines3/pull/110#issuecomment-663255241>`_.

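
As a minimal sketch of the optimizer swap described above (the ``CartPole-v1`` environment and the short training budget are illustrative assumptions, not tuned settings):

.. code-block:: python

  from stable_baselines3 import A2C
  from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

  # Replace the default optimizer with the TensorFlow-like RMSprop variant.
  # eps=1e-5 is the value quoted in the warning above.
  policy_kwargs = dict(
      optimizer_class=RMSpropTFLike,
      optimizer_kwargs=dict(eps=1e-5),
  )

  # "CartPole-v1" is only an illustrative choice of environment.
  model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs)
  model.learn(total_timesteps=1_000)

Passing ``optimizer_class`` explicitly takes the place of the optimizer that ``A2C`` would otherwise select.
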

Notes
-----

-  Original paper: https://arxiv.org/abs/1602.01783
-  OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/


Can I use?
----------

-  Recurrent policies: ❌
-  Multi processing: ✔️
-  Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ✔️      ✔️
MultiDiscrete ✔️      ✔️
MultiBinary   ✔️      ✔️
Dict          ❌      ✔️
============= ====== ===========


Example
-------

This example only demonstrates the use of the library and its functions; the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo `repository <https://github.com/DLR-RM/rl-baselines3-zoo>`_.

Train an A2C agent on ``CartPole-v1`` using 4 environments.


.. code-block:: python

  from stable_baselines3 import A2C
  from stable_baselines3.common.env_util import make_vec_env

  # Parallel environments
  vec_env = make_vec_env("CartPole-v1", n_envs=4)

  model = A2C("MlpPolicy", vec_env, verbose=1)
  model.learn(total_timesteps=25000)
  model.save("a2c_cartpole")

  del model  # remove to demonstrate saving and loading

  model = A2C.load("a2c_cartpole")

  obs = vec_env.reset()
  while True:
      action, _states = model.predict(obs)
      obs, rewards, dones, info = vec_env.step(action)
      vec_env.render("human")


.. note::

  A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

  .. code-block:: python

    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    if __name__ == "__main__":
        env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
        model = A2C("MlpPolicy", env, device="cpu")
        model.learn(total_timesteps=25_000)

  For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.


Results
-------

Atari Games
^^^^^^^^^^^

The complete learning curves are available in the `associated PR #110 <https://github.com/DLR-RM/stable-baselines3/pull/110>`_.


PyBullet Environments
^^^^^^^^^^^^^^^^^^^^^

Results on the PyBullet benchmark (2M steps) using 6 seeds.
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.


.. note::

  Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).


*Gaussian* means that the unstructured Gaussian noise is used for exploration,
*gSDE* (generalized State-Dependent Exploration) is used otherwise.

+--------------+--------------+--------------+--------------+-------------+
| Environments | A2C          | A2C          | PPO          | PPO         |
+==============+==============+==============+==============+=============+
|              | Gaussian     | gSDE         | Gaussian     | gSDE        |
+--------------+--------------+--------------+--------------+-------------+
| HalfCheetah  | 2003 +/- 54  | 2032 +/- 122 | 1976 +/- 479 | 2826 +/- 45 |
+--------------+--------------+--------------+--------------+-------------+
| Ant          | 2286 +/- 72  | 2443 +/- 89  | 2364 +/- 120 | 2782 +/- 76 |
+--------------+--------------+--------------+--------------+-------------+
| Hopper       | 1627 +/- 158 | 1561 +/- 220 | 1567 +/- 339 | 2512 +/- 21 |
+--------------+--------------+--------------+--------------+-------------+
| Walker2D     | 577 +/- 65   | 839 +/- 56   | 1230 +/- 147 | 2019 +/- 64 |
+--------------+--------------+--------------+--------------+-------------+
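
The *gSDE* columns correspond to runs with generalized State-Dependent Exploration enabled. A minimal sketch of switching it on, assuming the ``use_sde`` and ``sde_sample_freq`` constructor arguments of ``A2C``; ``Pendulum-v1`` and the values shown are illustrative only and are not the hyperparameters used for the benchmark:

.. code-block:: python

  from stable_baselines3 import A2C

  # gSDE requires a continuous (Box) action space;
  # "Pendulum-v1" is an illustrative stand-in for the PyBullet envs.
  model = A2C(
      "MlpPolicy",
      "Pendulum-v1",
      use_sde=True,       # gSDE instead of unstructured Gaussian noise
      sde_sample_freq=4,  # resample the exploration noise every 4 steps (illustrative value)
  )
  model.learn(total_timesteps=1_000)

Leaving ``use_sde`` at its default of ``False`` gives the *Gaussian* setting.
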

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

  git clone https://github.com/DLR-RM/rl-baselines3-zoo
  cd rl-baselines3-zoo/


Run the benchmark (replace ``$ENV_ID`` with one of the envs mentioned above):

.. code-block:: bash

  python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000


Plot the results (here for PyBullet envs only):

.. code-block:: bash

  python scripts/all_plots.py -a a2c -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/a2c_results
  python scripts/plot_from_file.py -i logs/a2c_results.pkl -latex -l A2C


Parameters
----------

.. autoclass:: A2C
  :members:
  :inherited-members:


.. _a2c_policies:

A2C Policies
-------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
  :members:
  :noindex:

.. autoclass:: CnnPolicy
  :members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
  :members:
  :noindex:

.. autoclass:: MultiInputPolicy
  :members:

.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy
  :members:
  :noindex:
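
The policy classes above are configured through the ``policy_kwargs`` argument of ``A2C``. A minimal sketch, assuming the ``net_arch`` and ``activation_fn`` parameters of ``ActorCriticPolicy``; the layer sizes and the ``CartPole-v1`` environment are illustrative choices, not recommended settings:

.. code-block:: python

  import torch as th

  from stable_baselines3 import A2C

  # Separate two-layer networks for the actor (pi) and the critic (vf).
  policy_kwargs = dict(
      activation_fn=th.nn.ReLU,
      net_arch=dict(pi=[32, 32], vf=[32, 32]),
  )

  model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs)
  print(model.policy)  # inspect the resulting network

Printing the policy is a quick way to check that the requested architecture was actually built.
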