Antonin Raffin 2019-09-26 11:46:40 +02:00
parent 70e5de1d1b
commit b4dc9d4e4d
23 changed files with 942 additions and 25 deletions

20
docs/Makefile Normal file

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS = -W # make warnings fatal
SPHINXBUILD = sphinx-build
SPHINXPROJ = StableBaselines
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

25
docs/README.md Normal file

@@ -0,0 +1,25 @@
## Stable Baselines Documentation
This folder contains documentation for the RL baselines.
### Build the Documentation
#### Install Sphinx and Theme
```
pip install sphinx sphinx-autobuild sphinx-rtd-theme
```
#### Building the Docs
In the `docs/` folder:
```
make html
```
If you want to rebuild the documentation automatically each time a file is changed:
```
sphinx-autobuild . _build/html
```

52
docs/_static/css/baselines_theme.css vendored Normal file

@@ -0,0 +1,52 @@
/* Main colors adapted from pytorch doc */
:root{
--main-bg-color: #343A40;
--link-color: #FD7E14;
}
/* Header fonts */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
/* Docs background */
.wy-side-nav-search{
background-color: var(--main-bg-color);
}
/* Mobile version */
.wy-nav-top{
background-color: var(--main-bg-color);
}
/* Change link colors (except for the menu) */
a {
color: var(--link-color);
}
a:hover {
color: #4F778F;
}
.wy-menu a {
color: #b3b3b3;
}
.wy-menu a:hover {
color: #b3b3b3;
}
a.icon.icon-home {
color: #b3b3b3;
}
.version{
color: var(--link-color) !important;
}
/* Make code blocks have a background */
.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
background: #f8f8f8;
}

203
docs/conf.py Normal file

@@ -0,0 +1,203 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
from unittest.mock import MagicMock
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath('..'))
class Mock(MagicMock):
    __subclasses__ = []

    @classmethod
    def __getattr__(cls, name):
        return MagicMock()
# Mock modules that require C modules
# Note: because of that, we cannot test the examples using CI
# 'torch', 'torch.nn', 'torch.nn.functional',
MOCK_MODULES = ['joblib', 'scipy', 'scipy.signal',
'pandas', 'mpi4py', 'mujoco-py', 'cv2',
'tensorflow', 'torch', 'torch.nn', 'torch.nn.functional',
'torch.distributions',
'tensorflow.contrib', 'tensorflow.contrib.layers',
'tensorflow.python', 'tensorflow.python.client', 'tensorflow.python.ops',
'tqdm', 'cloudpickle', 'matplotlib', 'matplotlib.pyplot',
'seaborn', 'gym', 'gym.spaces', 'gym.core',
'tensorflow.core', 'tensorflow.core.util', 'tensorflow.python.util',
'gym.wrappers', 'gym.wrappers.monitoring', 'zmq']
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
import torchy_baselines
# -- Project information -----------------------------------------------------
project = 'Torchy Baselines'
copyright = '2019, Torchy Baselines'
author = 'Torchy Baselines Contributors'
# The short X.Y version
version = 'master (' + torchy_baselines.__version__ + ')'
# The full version, including alpha/beta/rc tags
release = torchy_baselines.__version__
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.mathjax',
'sphinx.ext.ifconfig',
'sphinx.ext.viewcode',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffixes as a list of strings:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# Fix for read the docs
on_rtd = os.environ.get('READTHEDOCS') == 'True'
if on_rtd:
html_theme = 'default'
else:
html_theme = 'sphinx_rtd_theme'
html_logo = '_static/img/logo.png'
def setup(app):
    app.add_stylesheet("css/baselines_theme.css")
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'TorchyBaselinesdoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'TorchyBaselines.tex', 'Torchy Baselines Documentation',
'Torchy Baselines Contributors', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchybaselines', 'Torchy Baselines Documentation',
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'TorchyBaselines', 'Torchy Baselines Documentation',
author, 'TorchyBaselines', 'One line description of project.',
'Miscellaneous'),
]
# -- Extension configuration -------------------------------------------------

40
docs/guide/quickstart.rst Normal file

@@ -0,0 +1,40 @@
.. _quickstart:
===============
Getting Started
===============
Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.
Here is a quick example of how to train and run SAC on a Pendulum environment:
.. code-block:: python

   import gym

   from torchy_baselines.sac.policies import MlpPolicy
   from torchy_baselines.common.vec_env import DummyVecEnv
   from torchy_baselines import SAC

   # The algorithms require a vectorized environment to run
   env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])

   model = SAC(MlpPolicy, env, verbose=1)
   model.learn(total_timesteps=10000)

   obs = env.reset()
   for i in range(1000):
       action, _states = model.predict(obs)
       obs, rewards, dones, info = env.step(action)
       env.render()
Or just train a model with a one-liner if
`the environment is registered in Gym <https://github.com/openai/gym/wiki/Environments>`_ and if
the policy is registered:

.. code-block:: python

   from torchy_baselines import SAC

   model = SAC('MlpPolicy', 'Pendulum-v0').learn(10000)

71
docs/guide/vec_envs.rst Normal file

@@ -0,0 +1,71 @@
.. _vec_env:
.. automodule:: torchy_baselines.common.vec_env
Vectorized Environments
=======================
Vectorized Environments are a method for stacking multiple independent environments into a single environment.
Instead of training an RL agent on one environment per step, it allows us to train it on `n` environments per step.
Because of this, the `actions` passed to the environment are now a vector (of dimension `n`).
The same goes for `observations`, `rewards` and end-of-episode signals (`dones`).
In the case of non-array observation spaces such as `Dict` or `Tuple`, where different sub-spaces
may have different shapes, the sub-observations are vectors (of dimension `n`).
============= ======= ============ ======== ========= ================
Name ``Box`` ``Discrete`` ``Dict`` ``Tuple`` Multi Processing
============= ======= ============ ======== ========= ================
DummyVecEnv ✔️ ✔️ ✔️ ✔️ ❌️
SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️
============= ======= ============ ======== ========= ================
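Both implementations are built from a list of functions that each create one environment. A minimal sketch of the batched API, assuming the vectorized env exposes the usual ``action_space`` attribute of the underlying Gym env:

.. code-block:: python

   import gym

   from torchy_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv

   n_envs = 2
   # DummyVecEnv steps its environments sequentially in the current process
   env = DummyVecEnv([lambda: gym.make('Pendulum-v0') for _ in range(n_envs)])
   # SubprocVecEnv would instead run one worker process per environment:
   # env = SubprocVecEnv([lambda: gym.make('Pendulum-v0') for _ in range(n_envs)])

   obs = env.reset()  # batched: shape (n_envs,) + single observation shape
   actions = [env.action_space.sample() for _ in range(n_envs)]
   # rewards and dones have length n_envs, infos is a list of dicts
   obs, rewards, dones, infos = env.step(actions)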
.. note::
Vectorized environments are required when using wrappers for frame-stacking or normalization.
.. note::
When using vectorized environments, the environments are automatically reset at the end of each episode.
Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
You can access the "real" final observation of the terminated episode, that is, the one that accompanied the ``done`` signal from the underlying environment, via the ``terminal_observation`` key of the corresponding info dict returned by the vecenv.
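A sketch of retrieving it, assuming one info dict per sub-environment:

.. code-block:: python

   obs, rewards, dones, infos = env.step(actions)
   for idx, done in enumerate(dones):
       if done:
           # obs[idx] is already the first observation of the new episode;
           # the last observation of the finished one lives in the info dict
           last_obs = infos[idx]['terminal_observation']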
.. warning::
When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` block when using the ``forkserver`` or ``spawn`` start method (the default on Windows).
On Linux, the default start method is ``fork``, which is not thread safe and can create deadlocks.
For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.
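For example, a sketch of a safe ``SubprocVecEnv`` setup:

.. code-block:: python

   import gym

   from torchy_baselines.common.vec_env import SubprocVecEnv

   def make_env():
       return gym.make('Pendulum-v0')

   if __name__ == "__main__":
       # Required with spawn/forkserver: worker processes re-import this
       # module, and the guard keeps them from spawning workers themselves
       env = SubprocVecEnv([make_env for _ in range(4)])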
VecEnv
------
.. autoclass:: VecEnv
:members:
DummyVecEnv
-----------
.. autoclass:: DummyVecEnv
:members:
SubprocVecEnv
-------------
.. autoclass:: SubprocVecEnv
:members:
Wrappers
--------
VecFrameStack
~~~~~~~~~~~~~
.. autoclass:: VecFrameStack
:members:
VecNormalize
~~~~~~~~~~~~
.. autoclass:: VecNormalize
:members:
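As noted above, these wrappers expect an already vectorized environment. A minimal sketch, assuming the ``n_stack`` keyword and the ``VecNormalize`` defaults follow Stable Baselines:

.. code-block:: python

   import gym

   from torchy_baselines.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize

   env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
   env = VecFrameStack(env, n_stack=4)  # concatenate the last 4 observations
   env = VecNormalize(env)              # running normalization of obs and rewards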

63
docs/index.rst Normal file

@@ -0,0 +1,63 @@
.. Stable Baselines documentation master file, created by
sphinx-quickstart on Thu Sep 26 11:06:54 2019.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to the Torchy Baselines docs! - PyTorch RL Baselines
============================================================
`Torchy Baselines <https://github.com/hill-a/stable-baselines>`_ is the PyTorch version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_,
a set of improved implementations of reinforcement learning algorithms.
RL Baselines Zoo (a collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo
It also offers a simple interface to train and evaluate agents and to tune hyperparameters.
.. toctree::
:maxdepth: 2
:caption: User Guide
guide/quickstart
guide/vec_envs
.. toctree::
:maxdepth: 1
:caption: RL Algorithms
modules/base
modules/ppo
modules/sac
modules/td3
.. toctree::
:maxdepth: 1
:caption: Misc
misc/changelog
Citing Torchy Baselines
-----------------------
To cite this project in publications:
.. code-block:: bibtex
@misc{torchy-baselines,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi},
title = {Torchy Baselines},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/hill-a/stable-baselines}},
}
Indices and tables
-------------------
* :ref:`genindex`
* :ref:`search`
* :ref:`modindex`

36
docs/make.bat Normal file

@@ -0,0 +1,36 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=StableBaselines
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
popd

46
docs/misc/changelog.rst Normal file

@@ -0,0 +1,46 @@
.. _changelog:
Changelog
==========
Pre-Release 0.0.3a0 (WIP)
-------------------------
**Initial Release**
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Initial release of CEM-RL, PPO, SAC and TD3
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
Maintainers
-----------
Torchy-Baselines is currently maintained by `Antonin Raffin`_ (aka `@araffin`_).
.. _Antonin Raffin: https://araffin.github.io/
.. _@araffin: https://github.com/araffin
Contributors:
-------------
In random order...
Thanks to @hill-a @enerijunior @AdamGleave @Miffyli

12
docs/modules/base.rst Normal file

@@ -0,0 +1,12 @@
.. _base_algo:
.. automodule:: torchy_baselines.common.base_class
Base RL Class
=============
Common interface for all the RL algorithms
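A minimal sketch of that shared interface, reusing only the methods shown on the algorithm pages (``learn``, ``save``, ``load``, ``predict`` and ``get_env``):

.. code-block:: python

   from torchy_baselines import SAC  # every algorithm exposes the same interface

   model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
   model.learn(total_timesteps=10000)
   model.save("sac_pendulum")
   env = model.get_env()

   del model  # the saved model can be restored later

   model = SAC.load("sac_pendulum")
   obs = env.reset()
   action, _states = model.predict(obs)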
.. autoclass:: BaseRLModel
:members:

83
docs/modules/ppo.rst Normal file

@@ -0,0 +1,83 @@
.. _ppo2:
.. automodule:: torchy_baselines.ppo
PPO
===
The `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`_ algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).
The main idea is that after an update, the new policy should not be too far from the old policy.
For that, PPO uses clipping to avoid too large an update.
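Concretely, PPO maximizes the clipped surrogate objective from the paper, where :math:`r_t(\theta)` is the probability ratio between the new and old policy and :math:`\hat{A}_t` the estimated advantage:

.. math::

   L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\hat{A}_t,\; \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\hat{A}_t\right)\right]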
.. note::
PPO contains several modifications from the original algorithm not documented
by OpenAI: advantages are normalized and the value function can also be clipped.
Notes
-----
- Original paper: https://arxiv.org/abs/1707.06347
- Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
Can I use?
----------
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ❌
Box ✔️ ✔️
MultiDiscrete ❌ ❌
MultiBinary ❌ ❌
============= ====== ===========
Example
-------
Train a PPO agent on `Pendulum-v0` using 4 processes.
.. code-block:: python

   import gym

   from torchy_baselines.ppo.policies import MlpPolicy
   from torchy_baselines.common.vec_env import SubprocVecEnv
   from torchy_baselines import PPO

   # multiprocess environment
   n_cpu = 4
   env = SubprocVecEnv([lambda: gym.make('Pendulum-v0') for i in range(n_cpu)])

   model = PPO(MlpPolicy, env, verbose=1)
   model.learn(total_timesteps=25000)
   model.save("ppo_pendulum")

   del model  # remove to demonstrate saving and loading

   model = PPO.load("ppo_pendulum")

   # Enjoy trained agent
   obs = env.reset()
   while True:
       action, _states = model.predict(obs)
       obs, rewards, dones, info = env.step(action)
       env.render()
Parameters
----------
.. autoclass:: PPO
:members:
:inherited-members:

110
docs/modules/sac.rst Normal file

@@ -0,0 +1,110 @@
.. _sac:
.. automodule:: torchy_baselines.sac
SAC
===
`Soft Actor Critic (SAC) <https://spinningup.openai.com/en/latest/algorithms/sac.html>`_ Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
SAC is the successor of `Soft Q-Learning (SQL) <https://arxiv.org/abs/1702.08165>`_ and incorporates the double Q-learning trick from TD3.
A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
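Formally, the policy is trained to maximize the entropy-regularized return from the paper, with :math:`\alpha` the entropy (temperature) coefficient:

.. math::

   J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[r(s_t, a_t) + \alpha\, \mathcal{H}(\pi(\,\cdot \mid s_t))\right]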
.. warning::
The SAC model does not support ``torchy_baselines.common.policies`` because it uses double q-values
and value estimation; as a result, it must use its own policy models (see :ref:`sac_policies`).
.. rubric:: Available Policies
.. autosummary::
:nosignatures:
MlpPolicy
Notes
-----
- Original paper: https://arxiv.org/abs/1801.01290
- OpenAI Spinning Guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
- Original Implementation: https://github.com/haarnoja/sac
- Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/
.. note::
In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon),
which is equivalent to the inverse of the reward scale in the original SAC paper.
The main reason is that it avoids having too high errors when updating the Q functions.
.. note::
The default policy for SAC differs a bit from the MlpPolicy of the other algorithms: it uses ReLU instead of tanh activation,
to match the original paper.
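Both knobs from the notes above are set at model creation. A sketch, assuming the coefficient is exposed as an ``ent_coef`` constructor argument (as in Stable Baselines) and that ``policy_kwargs`` forwards ``activation_fn`` to the policy:

.. code-block:: python

   import torch.nn as nn

   from torchy_baselines import SAC

   # a smaller ent_coef means a smaller entropy bonus, i.e. less exploration
   model = SAC('MlpPolicy', 'Pendulum-v0', ent_coef=0.05,
               policy_kwargs=dict(activation_fn=nn.Tanh), verbose=1)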
Can I use?
----------
- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ❌
Box ✔️ ✔️
MultiDiscrete ❌ ❌
MultiBinary ❌ ❌
============= ====== ===========
Example
-------
.. code-block:: python

   import gym

   from torchy_baselines.sac.policies import MlpPolicy
   from torchy_baselines.common.vec_env import DummyVecEnv
   from torchy_baselines import SAC

   env = gym.make('Pendulum-v0')
   env = DummyVecEnv([lambda: env])

   model = SAC(MlpPolicy, env, verbose=1)
   model.learn(total_timesteps=50000, log_interval=10)
   model.save("sac_pendulum")

   del model  # remove to demonstrate saving and loading

   model = SAC.load("sac_pendulum")

   obs = env.reset()
   while True:
       action, _states = model.predict(obs)
       obs, rewards, dones, info = env.step(action)
       env.render()
Parameters
----------
.. autoclass:: SAC
:members:
:inherited-members:
.. _sac_policies:
SAC Policies
-------------
.. autoclass:: MlpPolicy
:members:
:inherited-members:

104
docs/modules/td3.rst Normal file

@@ -0,0 +1,104 @@
.. _td3:
.. automodule:: torchy_baselines.td3
TD3
===
`Twin Delayed DDPG (TD3) <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ Addressing Function Approximation Error in Actor-Critic Methods.
TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy update and target policy smoothing.
We recommend reading `OpenAI Spinning guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those.
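All three tricks show up in the critic target from the paper: the target action is smoothed with clipped noise, the smaller of the two target Q-values enters the Bellman backup, and the actor is updated less frequently than the critics:

.. math::

   y = r + \gamma \min_{i=1,2} Q_{\theta'_i}\big(s',\, \pi_{\phi'}(s') + \epsilon\big),
   \qquad \epsilon \sim \mathrm{clip}(\mathcal{N}(0, \tilde{\sigma}), -c, c)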
.. warning::
The TD3 model does not support ``torchy_baselines.common.policies`` because it uses double q-value
estimation; as a result, it must use its own policy models (see :ref:`td3_policies`).
.. rubric:: Available Policies
.. autosummary::
:nosignatures:
MlpPolicy
Notes
-----
- Original paper: https://arxiv.org/pdf/1802.09477.pdf
- OpenAI Spinning Guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
- Original Implementation: https://github.com/sfujim/TD3
.. note::
The default policy for TD3 differs a bit from the MlpPolicy of the other algorithms: it uses ReLU instead of tanh activation,
to match the original paper.
Can I use?
----------
- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:
============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ❌
Box ✔️ ✔️
MultiDiscrete ❌ ❌
MultiBinary ❌ ❌
============= ====== ===========
Example
-------
.. code-block:: python

   import gym
   import numpy as np

   from torchy_baselines import TD3
   from torchy_baselines.td3.policies import MlpPolicy
   from torchy_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

   # The noise objects for TD3
   # (an env instance is created here just to read the action dimension)
   n_actions = gym.make('Pendulum-v0').action_space.shape[-1]
   action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

   model = TD3(MlpPolicy, 'Pendulum-v0', action_noise=action_noise, verbose=1)
   model.learn(total_timesteps=50000, log_interval=10)
   model.save("td3_pendulum")
   env = model.get_env()

   del model  # remove to demonstrate saving and loading

   model = TD3.load("td3_pendulum")

   obs = env.reset()
   while True:
       action, _states = model.predict(obs)
       obs, rewards, dones, info = env.step(action)
       env.render()
Parameters
----------
.. autoclass:: TD3
:members:
:inherited-members:
.. _td3_policies:
TD3 Policies
-------------
.. autoclass:: MlpPolicy
:members:
:inherited-members:

torchy_baselines/cem_rl/__init__.py

@@ -1 +1,2 @@
from torchy_baselines.cem_rl.cem_rl import CEMRL
from torchy_baselines.td3.policies import MlpPolicy

torchy_baselines/common/distributions.py

@@ -52,6 +52,7 @@ class DiagGaussianDistribution(Distribution):
def proba_distribution_net(self, latent_dim, log_std_init=0.0):
mean_actions = nn.Linear(latent_dim, self.action_dim)
# TODO: allow action dependent std
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init)
return mean_actions, log_std
@@ -73,6 +74,11 @@ class DiagGaussianDistribution(Distribution):
def entropy(self):
return self.distribution.entropy()
def log_prob_from_params(self, mean_actions, log_std):
action, _ = self.proba_distribution(mean_actions, log_std)
log_prob = self.log_prob(action)
return action, log_prob
def log_prob(self, action):
log_prob = self.distribution.log_prob(action)
if len(log_prob.shape) > 1:
@@ -87,6 +93,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
self.gaussian_action = None
def proba_distribution(self, mean_actions, log_std, deterministic=False):
action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, deterministic)
@@ -114,6 +121,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
# Log likelihood for a gaussian distribution
log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
# Squash correction (from original implementation)
# Squash correction (from original SAC implementation)
log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
return log_prob

torchy_baselines/common/policies.py

@@ -124,10 +124,14 @@ def register_policy(name, policy):
:param policy: (subclass of BasePolicy) the policy
"""
sub_class = None
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
# For building the doc
try:
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
except AttributeError:
sub_class = str(th.random.randint(100))
if sub_class is None:
raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy))

torchy_baselines/ppo/__init__.py

@@ -1 +1,2 @@
from torchy_baselines.ppo.ppo import PPO
from torchy_baselines.ppo.policies import MlpPolicy

torchy_baselines/ppo/ppo.py

@@ -5,6 +5,7 @@ from copy import deepcopy
import gym
import torch as th
import torch.nn.functional as F
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
@@ -21,28 +22,63 @@ from torchy_baselines.ppo.policies import PPOPolicy
class PPO(BaseRLModel):
"""
Implementation of Proximal Policy Optimization (PPO) (clip version)
Proximal Policy Optimization algorithm (PPO) (clip version)
Paper: https://arxiv.org/abs/1707.06347
Code: https://github.com/openai/spinningup/
and https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
and stable_baselines
Code: This implementation borrows code from OpenAI spinningup (https://github.com/openai/spinningup/),
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
:param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
:param learning_rate: (float or callable) The learning rate, it can be a function
:param n_steps: (int) The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param batch_size: (int) Minibatch size
:param n_epochs: (int) Number of epochs when optimizing the surrogate loss
:param gamma: (float) Discount factor
:param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
:param clip_range: (float or callable) Clipping parameter, it can be a function
:param clip_range_vf: (float or callable) Clipping parameter for the value function, it can be a function.
This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling.
:param ent_coef: (float) Entropy coefficient for the loss calculation
:param vf_coef: (float) Value function coefficient for the loss calculation
:param max_grad_norm: (float) The maximum value for the gradient clipping
:param target_kl: (float) Limit the KL divergence between updates,
because the clipping is not enough to prevent large updates,
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
By default, there is no limit on the kl div.
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param create_eval_env: (bool) Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param seed: (int) Seed for the pseudo random generators
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
def __init__(self, policy, env, policy_kwargs=None, verbose=0,
learning_rate=3e-4, seed=0, device='auto',
n_optim=5, batch_size=64, n_steps=256,
gamma=0.99, gae_lambda=0.95, clip_range=0.2,
ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
target_kl=None, clip_range_vf=None, create_eval_env=False,
tensorboard_log=None,
def __init__(self, policy, env, learning_rate=3e-4,
n_steps=2048, batch_size=64, n_epochs=10,
gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None,
ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5,
target_kl=None, tensorboard_log=None, create_eval_env=False,
policy_kwargs=None, verbose=0, seed=0, device='auto',
_init_setup_model=True):
super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs,
verbose, device, create_eval_env=create_eval_env, support_multi_env=True)
super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs,
verbose=verbose, device=device,
create_eval_env=create_eval_env, support_multi_env=True)
self.learning_rate = learning_rate
self.seed = seed
self.batch_size = batch_size
self.n_optim = n_optim
self.n_epochs = n_epochs
self.n_steps = n_steps
self.gamma = gamma
self.gae_lambda = gae_lambda
@@ -177,7 +213,7 @@ class PPO(BaseRLModel):
obs = self.env.reset()
eval_env = self._get_eval_env(eval_env)
if self.tensorboard_log is not None:
if self.tensorboard_log is not None and SummaryWriter is not None:
self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
while self.num_timesteps < total_timesteps:
@@ -193,7 +229,7 @@ class PPO(BaseRLModel):
self.num_timesteps += self.n_steps * self.n_envs
timesteps_since_eval += self.n_steps * self.n_envs
self.train(self.n_optim, batch_size=self.batch_size)
self.train(self.n_epochs, batch_size=self.batch_size)
# Evaluate agent
if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:

torchy_baselines/sac/__init__.py

@@ -1 +1,2 @@
from torchy_baselines.sac.sac import SAC
from torchy_baselines.sac.policies import MlpPolicy

torchy_baselines/sac/policies.py

@@ -69,7 +69,7 @@ class Critic(BaseNetwork):
class SACPolicy(BasePolicy):
def __init__(self, observation_space, action_space,
learning_rate=1e-3, net_arch=None, device='cpu',
learning_rate=3e-4, net_arch=None, device='cpu',
activation_fn=nn.ReLU):
super(SACPolicy, self).__init__(observation_space, action_space, device)
self.obs_dim = self.observation_space.shape[0]
@@ -96,9 +96,6 @@ class SACPolicy(BasePolicy):
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate)
def actor_forward(self, state, deterministic=False):
pass
def make_actor(self):
return Actor(**self.net_args).to(self.device)

torchy_baselines/sac/sac.py

@@ -48,6 +48,8 @@ class SAC(BaseRLModel):
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param seed: (int) Seed for the pseudo random generators
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
def __init__(self, policy, env, learning_rate=3e-4, buffer_size=int(1e6),

torchy_baselines/td3/__init__.py

@@ -1 +1,2 @@
from torchy_baselines.td3.td3 import TD3
from torchy_baselines.td3.policies import MlpPolicy

torchy_baselines/td3/td3.py

@@ -42,6 +42,8 @@ class TD3(BaseRLModel):
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
:param seed: (int) Seed for the pseudo random generators
:param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3,