.. _sac:

.. automodule:: stable_baselines3.sac


SAC
===

`Soft Actor Critic (SAC) <https://spinningup.openai.com/en/latest/algorithms/sac.html>`_: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

SAC is the successor of `Soft Q-Learning (SQL) <https://arxiv.org/abs/1702.08165>`_ and incorporates the clipped double Q-learning trick from TD3.

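To make the double Q-learning trick concrete, here is a minimal pure-Python sketch of the soft Bellman target used to train the critics, following the formulation in the Spinning Up guide linked under Notes. The helper ``soft_q_target`` and its scalar arguments are hypothetical and chosen for illustration only; a real implementation would operate on batched tensors.

.. code-block:: python

  def soft_q_target(reward, done, next_q1, next_q2, next_log_prob, gamma, alpha):
      """Soft Bellman target for a single transition (hypothetical helper).

      next_q1, next_q2: target-critic estimates Q1'(s', a') and Q2'(s', a'),
      where a' is sampled from the current policy at the next state s'.
      next_log_prob: log pi(a' | s') for that sampled action.
      """
      # Clipped double Q-learning (from TD3): keep the smaller of the two
      # target-critic estimates to reduce overestimation bias.
      min_next_q = min(next_q1, next_q2)
      # Entropy bonus: subtracting alpha * log pi(a'|s') adds alpha times a
      # one-sample estimate of the policy entropy at s'.
      soft_next_value = min_next_q - alpha * next_log_prob
      # Do not bootstrap past the end of an episode.
      return reward + gamma * (1.0 - float(done)) * soft_next_value

For example, ``soft_q_target(1.0, False, 10.0, 12.0, -1.0, gamma=0.9, alpha=0.5)`` keeps the smaller estimate ``10.0``, adds the entropy bonus ``0.5`` to get ``10.5``, and returns ``1.0 + 0.9 * 10.5``, i.e. about ``10.45``.
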
A key feature of SAC, and a major difference from common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.

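Written out, the maximum-entropy objective from the original paper (linked under Notes) adds to the expected reward an entropy term weighted by a temperature :math:`\alpha > 0`:

.. math::

    J(\pi) = \sum_{t=0}^{T} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]

    \mathcal{H}\big(\pi(\cdot \mid s_t)\big) = -\mathbb{E}_{a \sim \pi(\cdot \mid s_t)} \left[ \log \pi(a \mid s_t) \right]

Here :math:`\rho_\pi` denotes the state-action distribution induced by the policy :math:`\pi`; setting :math:`\alpha = 0` recovers the standard expected-return objective.
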
.. rubric:: Available Policies

.. autosummary::
  :nosignatures:

  MlpPolicy
  CnnPolicy
  MultiInputPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1801.01290
- OpenAI Spinning Up guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
- Original implementation: https://github.com/haarnoja/sac
- Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/

.. note::

  In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon),
  which is equivalent to the inverse of the reward scale in the original SAC paper.
  The main reason is that it avoids excessively large errors when updating the Q-functions.

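The equivalence stated in the note can be checked by factoring the entropy coefficient :math:`\alpha > 0` out of the maximum-entropy objective, where :math:`\rho_\pi` is the state-action distribution induced by the policy :math:`\pi` and :math:`\mathcal{H}` is the policy entropy:

.. math::

    J(\pi) = \sum_{t=0}^{T} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]
    = \alpha \sum_{t=0}^{T} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \left[ \frac{1}{\alpha} \, r(s_t, a_t) + \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]

Multiplying by the positive constant :math:`\alpha` does not change which policy maximizes :math:`J`, so weighting the entropy by :math:`\alpha` optimizes the same policy as scaling the reward by :math:`1/\alpha` with unit entropy weight.
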
.. note::

  The default policies for SAC differ slightly from the ``MlpPolicy`` of other algorithms: they use ReLU instead of tanh activations,
  to match the original paper.


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌     ✔️
Box           ✔️     ✔️
MultiDiscrete ❌     ✔️
MultiBinary   ❌     ✔️
Dict          ❌     ✔️
============= ====== ===========


Example
-------

This example only demonstrates how to use the library and its functions; the trained agents may not solve the environments. Optimized hyperparameters can be found in the RL Zoo `repository <https://github.com/DLR-RM/rl-baselines3-zoo>`_.

.. code-block:: python

  import gym
  import numpy as np

  from stable_baselines3 import SAC

  env = gym.make("Pendulum-v1")

  model = SAC("MlpPolicy", env, verbose=1)
  model.learn(total_timesteps=10000, log_interval=4)
  model.save("sac_pendulum")

  del model # remove to demonstrate saving and loading

  model = SAC.load("sac_pendulum")

  # Run the trained agent until interrupted (e.g. with Ctrl+C),
  # resetting the environment at the end of each episode
  obs = env.reset()
  while True:
      action, _states = model.predict(obs, deterministic=True)
      obs, reward, done, info = env.step(action)
      env.render()
      if done:
          obs = env.reset()

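The example passes ``deterministic=True`` to ``model.predict``. The scalar sketch below illustrates what that flag means for a stochastic actor that, following the SAC paper's approach to bounded actions, draws from a Gaussian, squashes the sample with ``tanh`` into ``(-1, 1)``, and rescales it to the action-space bounds; the deterministic action uses the Gaussian mean instead of a sample. This is an illustration under those assumptions, not the library's code: ``squashed_gaussian_action`` and its arguments are hypothetical, and the ``[-2, 2]`` bounds are the torque limits of ``Pendulum-v1``.

.. code-block:: python

  import math
  import random


  def squashed_gaussian_action(mean, log_std, low, high, deterministic, rng):
      """Pick the mode of, or sample from, a tanh-squashed Gaussian (hypothetical helper)."""
      if deterministic:
          # Deterministic mode: use the Gaussian mean, no noise.
          pre_tanh = mean
      else:
          # Stochastic mode: reparameterized sample mean + std * eps, eps ~ N(0, 1).
          pre_tanh = mean + math.exp(log_std) * rng.gauss(0.0, 1.0)
      # tanh squashing keeps the action inside (-1, 1).
      squashed = math.tanh(pre_tanh)
      # Linearly rescale from [-1, 1] to the action-space bounds [low, high].
      return low + 0.5 * (squashed + 1.0) * (high - low)


  if __name__ == "__main__":
      rng = random.Random(0)
      # Pendulum-v1 torque bounds are [-2, 2].
      print(squashed_gaussian_action(0.3, -0.5, -2.0, 2.0, deterministic=True, rng=rng))
      print(squashed_gaussian_action(0.3, -0.5, -2.0, 2.0, deterministic=False, rng=rng))

Repeated deterministic calls return the same action, while stochastic calls vary but always stay within ``[low, high]``.
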
Results
-------

PyBullet Environments
^^^^^^^^^^^^^^^^^^^^^

Results on the PyBullet benchmark (1M steps) using 3 seeds.
The complete learning curves are available in the `associated issue #48 <https://github.com/DLR-RM/stable-baselines3/issues/48>`_.


.. note::

  Hyperparameters from the `gSDE paper <https://arxiv.org/abs/2005.05719>`_ were used (as they are tuned for PyBullet envs).


*Gaussian* means that unstructured Gaussian noise is used for exploration;
*gSDE* means that generalized State-Dependent Exploration is used instead.

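To illustrate the difference, here is a simplified pure-Python sketch, not the library's implementation. Unstructured Gaussian exploration draws independent noise at every step. State-dependent exploration instead samples a noise matrix :math:`\theta_\varepsilon \sim \mathcal{N}(0, \sigma^2)` and keeps it fixed, so the noise :math:`\varepsilon(s) = \theta_\varepsilon z(s)` is a deterministic function of the features :math:`z(s)` until the matrix is resampled. In the gSDE paper the features come from the policy network and the matrix is resampled every ``n`` steps; this sketch simply uses the raw state as features, and all function names are hypothetical.

.. code-block:: python

  import random


  def sample_noise_matrix(n_actions, n_features, sigma, rng):
      # theta_eps ~ N(0, sigma^2), one entry per (action, feature) pair.
      return [[rng.gauss(0.0, sigma) for _ in range(n_features)] for _ in range(n_actions)]


  def state_dependent_noise(features, theta_eps):
      # eps(s) = theta_eps @ z(s): deterministic in s while theta_eps is fixed.
      return [sum(w * f for w, f in zip(row, features)) for row in theta_eps]


  def gaussian_noise(n_actions, sigma, rng):
      # Unstructured exploration: fresh, independent noise at every call.
      return [rng.gauss(0.0, sigma) for _ in range(n_actions)]


  def explore(states, n_actions, sigma, sample_freq, rng):
      """gSDE-style noise for a sequence of feature vectors, resampling every sample_freq steps."""
      noises = []
      theta_eps = None
      for step, features in enumerate(states):
          if step % sample_freq == 0:
              theta_eps = sample_noise_matrix(n_actions, len(features), sigma, rng)
          noises.append(state_dependent_noise(features, theta_eps))
      return noises


  if __name__ == "__main__":
      rng = random.Random(42)
      state = [0.1, -0.2, 0.3]  # stand-in for the policy features z(s)
      # Same state, same matrix: identical noise for steps 0-3, new noise from step 4.
      for step, noise in enumerate(explore([state] * 8, n_actions=2, sigma=0.5, sample_freq=4, rng=rng)):
          print(step, noise)
      # Unstructured noise differs between calls even for the same state.
      print(gaussian_noise(2, 0.5, rng), gaussian_noise(2, 0.5, rng))

Because the noise depends only on the state between resamples, revisiting a state yields the same perturbation, which is what makes the resulting exploration smoother than independent per-step noise.
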
+--------------+--------------+--------------+--------------+
| Environments | SAC          | SAC          | TD3          |
+--------------+--------------+--------------+--------------+
|              | Gaussian     | gSDE         | Gaussian     |
+==============+==============+==============+==============+
| HalfCheetah  | 2757 +/- 53  | 2984 +/- 202 | 2774 +/- 35  |
+--------------+--------------+--------------+--------------+
| Ant          | 3146 +/- 35  | 3102 +/- 37  | 3305 +/- 43  |
+--------------+--------------+--------------+--------------+
| Hopper       | 2422 +/- 168 | 2262 +/- 1   | 2429 +/- 126 |
+--------------+--------------+--------------+--------------+
| Walker2D     | 2184 +/- 54  | 2136 +/- 67  | 2063 +/- 185 |
+--------------+--------------+--------------+--------------+


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

  git clone https://github.com/DLR-RM/rl-baselines3-zoo
  cd rl-baselines3-zoo/


Run the benchmark (replace ``$ENV_ID`` with each of the environments listed above):

.. code-block:: bash

  python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000


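To run the benchmark on all four environments in one go, a small Python driver can build and launch the same ``train.py`` command for each environment ID. This is a sketch, not part of the RL Zoo: the helper names are hypothetical, and the PyBullet IDs below (``HalfCheetahBulletEnv-v0`` and so on) are assumptions to check against the RL Zoo hyperparameter files. Run it from inside ``rl-baselines3-zoo/`` with ``dry_run=False`` to actually train.

.. code-block:: python

  import shlex
  import subprocess

  # Assumed PyBullet IDs for the four benchmark environments (verify in the RL Zoo).
  ENV_IDS = [
      "HalfCheetahBulletEnv-v0",
      "AntBulletEnv-v0",
      "HopperBulletEnv-v0",
      "Walker2DBulletEnv-v0",
  ]


  def benchmark_command(env_id, eval_episodes=10, eval_freq=10000):
      # Same flags as the command above, with $ENV_ID filled in.
      return [
          "python", "train.py",
          "--algo", "sac",
          "--env", env_id,
          "--eval-episodes", str(eval_episodes),
          "--eval-freq", str(eval_freq),
      ]


  def run_benchmark(env_ids, dry_run=True):
      commands = [benchmark_command(env_id) for env_id in env_ids]
      for cmd in commands:
          # Print a shell-ready version of each command.
          print(shlex.join(cmd))
          if not dry_run:
              # Runs sequentially; check=True stops at the first failing run.
              subprocess.run(cmd, check=True)
      return commands


  if __name__ == "__main__":
      run_benchmark(ENV_IDS, dry_run=True)

Running the environments one after another keeps the sketch simple; for a faster sweep, each command could instead be launched in its own process or job.
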
Plot the results:

.. code-block:: bash

  python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results
  python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC


Parameters
----------

.. autoclass:: SAC
  :members:
  :inherited-members:

.. _sac_policies:

SAC Policies
-------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: stable_baselines3.sac.policies.SACPolicy
  :members:
  :noindex:

.. autoclass:: CnnPolicy
  :members:

.. autoclass:: MultiInputPolicy
  :members: