mirror of https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Add doc

parent 70e5de1d1b
commit b4dc9d4e4d

23 changed files with 942 additions and 25 deletions
20  docs/Makefile  Normal file
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    = -W  # make warnings fatal
SPHINXBUILD   = sphinx-build
SPHINXPROJ    = StableBaselines
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
25  docs/README.md  Normal file
@@ -0,0 +1,25 @@
## Stable Baselines Documentation

This folder contains documentation for the RL baselines.


### Build the Documentation

#### Install Sphinx and Theme

```
pip install sphinx sphinx-autobuild sphinx-rtd-theme
```

#### Building the Docs

In the `docs/` folder:
```
make html
```

If you want to rebuild the documentation each time a file is changed:

```
sphinx-autobuild . _build/html
```
52  docs/_static/css/baselines_theme.css  vendored  Normal file
@@ -0,0 +1,52 @@
/* Main colors adapted from pytorch doc */
:root{
  --main-bg-color: #343A40;
  --link-color: #FD7E14;
}

/* Header fonts */
h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
  font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}


/* Docs background */
.wy-side-nav-search{
  background-color: var(--main-bg-color);
}

/* Mobile version */
.wy-nav-top{
  background-color: var(--main-bg-color);
}

/* Change link colors (except for the menu) */
a {
  color: var(--link-color);
}

a:hover {
  color: #4F778F;
}

.wy-menu a {
  color: #b3b3b3;
}

.wy-menu a:hover {
  color: #b3b3b3;
}

a.icon.icon-home {
  color: #b3b3b3;
}

.version{
  color: var(--link-color) !important;
}


/* Make code blocks have a background */
.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
  background: #f8f8f8;
}
203  docs/conf.py  Normal file
@@ -0,0 +1,203 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
from unittest.mock import MagicMock

# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath('..'))


class Mock(MagicMock):
    __subclasses__ = []

    @classmethod
    def __getattr__(cls, name):
        return MagicMock()


# Mock modules that require C modules
# Note: because of that we cannot test examples using CI
# 'torch', 'torch.nn', 'torch.nn.functional',
MOCK_MODULES = ['joblib', 'scipy', 'scipy.signal',
                'pandas', 'mpi4py', 'mujoco-py', 'cv2',
                'tensorflow', 'torch', 'torch.nn', 'torch.nn.functional',
                'torch.distributions',
                'tensorflow.contrib', 'tensorflow.contrib.layers',
                'tensorflow.python', 'tensorflow.python.client', 'tensorflow.python.ops',
                'tqdm', 'cloudpickle', 'matplotlib', 'matplotlib.pyplot',
                'seaborn', 'gym', 'gym.spaces', 'gym.core',
                'tensorflow.core', 'tensorflow.core.util', 'tensorflow.python.util',
                'gym.wrappers', 'gym.wrappers.monitoring', 'zmq']
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)


import torchy_baselines


# -- Project information -----------------------------------------------------

project = 'Torchy Baselines'
copyright = '2019, Torchy Baselines'
author = 'Torchy Baselines Contributors'

# The short X.Y version
version = 'master (' + torchy_baselines.__version__ + ')'
# The full version, including alpha/beta/rc tags
release = torchy_baselines.__version__


# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.mathjax',
    'sphinx.ext.ifconfig',
    'sphinx.ext.viewcode',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.

# Fix for read the docs
on_rtd = os.environ.get('READTHEDOCS') == 'True'
if on_rtd:
    html_theme = 'default'
else:
    html_theme = 'sphinx_rtd_theme'

html_logo = '_static/img/logo.png'


def setup(app):
    app.add_stylesheet("css/baselines_theme.css")

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'TorchyBaselinesdoc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'TorchyBaselines.tex', 'Torchy Baselines Documentation',
     'Torchy Baselines Contributors', 'manual'),
]


# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'torchybaselines', 'Torchy Baselines Documentation',
     [author], 1)
]


# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'TorchyBaselines', 'Torchy Baselines Documentation',
     author, 'TorchyBaselines', 'One line description of project.',
     'Miscellaneous'),
]


# -- Extension configuration -------------------------------------------------
40  docs/guide/quickstart.rst  Normal file
@@ -0,0 +1,40 @@
.. _quickstart:

===============
Getting Started
===============

Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms.

Here is a quick example of how to train and run SAC on a Pendulum environment:

.. code-block:: python

    import gym

    from torchy_baselines.sac.policies import MlpPolicy
    from torchy_baselines.common.vec_env import DummyVecEnv
    from torchy_baselines import SAC

    # The algorithms require a vectorized environment to run
    env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])

    model = SAC(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=10000)

    obs = env.reset()
    for i in range(1000):
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()


Or just train a model with a one-liner if
`the environment is registered in Gym <https://github.com/openai/gym/wiki/Environments>`_ and if
the policy is registered:

.. code-block:: python

    from torchy_baselines import SAC

    model = SAC('MlpPolicy', 'Pendulum-v0').learn(10000)
71  docs/guide/vec_envs.rst  Normal file
@@ -0,0 +1,71 @@
.. _vec_env:

.. automodule:: torchy_baselines.common.vec_env

Vectorized Environments
=======================

Vectorized Environments are a method for stacking multiple independent environments into a single environment.
Instead of training an RL agent on 1 environment per step, it allows us to train it on `n` environments per step.
Because of this, `actions` passed to the environment are now a vector (of dimension `n`).
It is the same for `observations`, `rewards` and end of episode signals (`dones`).
In the case of non-array observation spaces such as `Dict` or `Tuple`, where different sub-spaces
may have different shapes, the sub-observations are vectors (of dimension `n`).

============= ======= ============ ======== ========= ================
Name          ``Box`` ``Discrete`` ``Dict`` ``Tuple`` Multi Processing
============= ======= ============ ======== ========= ================
DummyVecEnv   ✔️       ✔️            ✔️        ✔️         ❌️
SubprocVecEnv ✔️       ✔️            ✔️        ✔️         ✔️
============= ======= ============ ======== ========= ================
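
For instance, stacking four copies of ``Pendulum-v0`` yields batched observations, rewards and ``dones`` (a minimal sketch; the exact shapes depend on the wrapped environment):

.. code-block:: python

    import gym
    import numpy as np

    from torchy_baselines.common.vec_env import DummyVecEnv

    # 4 independent copies of the same environment
    env = DummyVecEnv([lambda: gym.make('Pendulum-v0') for _ in range(4)])

    obs = env.reset()
    print(obs.shape)  # (4, 3): one observation per environment

    # one action per environment
    actions = np.stack([env.action_space.sample() for _ in range(4)])
    obs, rewards, dones, infos = env.step(actions)
    print(rewards.shape, dones.shape)  # (4,) (4,)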

.. note::

    Vectorized environments are required when using wrappers for frame-stacking or normalization.

.. note::

    When using vectorized environments, the environments are automatically reset at the end of each episode.
    Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
    You can access the "real" final observation of the terminated episode (that is, the one that accompanied the ``done`` event provided by the underlying environment) using the ``terminal_observation`` keys in the info dicts returned by the vecenv.
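
    A minimal sketch of retrieving that final observation (assuming ``env`` is a vecenv and ``actions`` is a batch of actions):

    .. code-block:: python

        obs, rewards, dones, infos = env.step(actions)
        for i, done in enumerate(dones):
            if done:
                # obs[i] is already the first observation of the next episode;
                # the last observation of the finished episode is in the info dict
                final_obs = infos[i]['terminal_observation']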

.. warning::

    When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` block if using the ``forkserver`` or ``spawn`` start method (default on Windows).
    On Linux, the default start method is ``fork``, which is not thread safe and can create deadlocks.

    For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.
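
For example, a sketch of the recommended layout:

.. code-block:: python

    import gym

    from torchy_baselines.common.vec_env import SubprocVecEnv

    def make_env():
        return gym.make('Pendulum-v0')

    if __name__ == '__main__':
        # worker processes are spawned here, so this line must be guarded
        env = SubprocVecEnv([make_env for _ in range(4)])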

VecEnv
------

.. autoclass:: VecEnv
  :members:

DummyVecEnv
-----------

.. autoclass:: DummyVecEnv
  :members:

SubprocVecEnv
-------------

.. autoclass:: SubprocVecEnv
  :members:

Wrappers
--------

VecFrameStack
~~~~~~~~~~~~~

.. autoclass:: VecFrameStack
  :members:


VecNormalize
~~~~~~~~~~~~

.. autoclass:: VecNormalize
  :members:
63  docs/index.rst  Normal file
@@ -0,0 +1,63 @@
.. Stable Baselines documentation master file, created by
   sphinx-quickstart on Thu Sep 26 11:06:54 2019.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to Torchy Baselines docs! - PyTorch RL Baselines
========================================================

`Torchy Baselines <https://github.com/hill-a/stable-baselines>`_ is the PyTorch version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_,
a set of improved implementations of reinforcement learning algorithms.

RL Baselines Zoo (collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo

RL Baselines Zoo also offers a simple interface to train and evaluate agents, and to do hyperparameter tuning.


.. toctree::
   :maxdepth: 2
   :caption: User Guide

   guide/quickstart
   guide/vec_envs


.. toctree::
   :maxdepth: 1
   :caption: RL Algorithms

   modules/base
   modules/ppo
   modules/sac
   modules/td3


.. toctree::
   :maxdepth: 1
   :caption: Misc

   misc/changelog


Citing Torchy Baselines
-----------------------
To cite this project in publications:

.. code-block:: bibtex

    @misc{torchy-baselines,
      author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi},
      title = {Torchy Baselines},
      year = {2019},
      publisher = {GitHub},
      journal = {GitHub repository},
      howpublished = {\url{https://github.com/hill-a/stable-baselines}},
    }

Indices and tables
-------------------

* :ref:`genindex`
* :ref:`search`
* :ref:`modindex`
36  docs/make.bat  Normal file
@@ -0,0 +1,36 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=StableBaselines

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd
46  docs/misc/changelog.rst  Normal file
@@ -0,0 +1,46 @@
.. _changelog:

Changelog
==========


Pre-Release 0.0.3a0 (WIP)
-------------------------
**Initial Release**

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Initial release of CEM-RL, PPO, SAC and TD3

Bug Fixes:
^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^


Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Maintainers
-----------

Torchy-Baselines is currently maintained by `Antonin Raffin`_ (aka `@araffin`_).

.. _Antonin Raffin: https://araffin.github.io/
.. _@araffin: https://github.com/araffin



Contributors:
-------------
In random order...

Thanks to @hill-a @enerijunior @AdamGleave @Miffyli
12  docs/modules/base.rst  Normal file
@@ -0,0 +1,12 @@
.. _base_algo:

.. automodule:: torchy_baselines.common.base_class


Base RL Class
=============

Common interface for all the RL algorithms.

.. autoclass:: BaseRLModel
  :members:
83  docs/modules/ppo.rst  Normal file
@@ -0,0 +1,83 @@
.. _ppo2:

.. automodule:: torchy_baselines.ppo

PPO
===

The `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`_ algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).

The main idea is that after an update, the new policy should not be too far from the old policy.
For that, PPO uses clipping to avoid too large an update.
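
As an illustration, the clipped surrogate loss can be sketched as follows (a minimal sketch, assuming ``log_prob``, ``old_log_prob`` and ``advantages`` are tensors computed from rollouts and ``clip_range`` is a float; the actual implementation differs in details):

.. code-block:: python

    import torch as th

    # probability ratio between the new and the old policy
    ratio = th.exp(log_prob - old_log_prob)

    # unclipped and clipped surrogate objectives
    surrogate_1 = advantages * ratio
    surrogate_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)

    # pessimistic bound: take the minimum, negate to get a loss to minimize
    policy_loss = -th.min(surrogate_1, surrogate_2).mean()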


.. note::

    PPO contains several modifications from the original algorithm not documented
    by OpenAI: advantages are normalized and the value function can also be clipped.


Notes
-----

- Original paper: https://arxiv.org/abs/1707.06347
- Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ❌
Box           ✔️      ✔️
MultiDiscrete ❌      ❌
MultiBinary   ❌      ❌
============= ====== ===========

Example
-------

Train a PPO agent on `Pendulum-v0` using 4 processes.

.. code-block:: python

    import gym

    from torchy_baselines.ppo.policies import MlpPolicy
    from torchy_baselines.common.vec_env import SubprocVecEnv
    from torchy_baselines import PPO

    # multiprocess environment
    n_cpu = 4
    env = SubprocVecEnv([lambda: gym.make('Pendulum-v0') for i in range(n_cpu)])

    model = PPO(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=25000)
    model.save("ppo_pendulum")

    del model  # remove to demonstrate saving and loading

    model = PPO.load("ppo_pendulum")

    # Enjoy trained agent
    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()

Parameters
----------

.. autoclass:: PPO
  :members:
  :inherited-members:
110  docs/modules/sac.rst  Normal file
@@ -0,0 +1,110 @@
.. _sac:

.. automodule:: torchy_baselines.sac


SAC
===

`Soft Actor Critic (SAC) <https://spinningup.openai.com/en/latest/algorithms/sac.html>`_: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

SAC is the successor of `Soft Q-Learning SQL <https://arxiv.org/abs/1702.08165>`_ and incorporates the double Q-learning trick from TD3.
A key feature of SAC, and a major difference from common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.


.. warning::

    The SAC model does not support ``torchy_baselines.common.policies`` because it uses double q-values
    and value estimation; as a result, it must use its own policy models (see :ref:`sac_policies`).


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1801.01290
- OpenAI Spinning Guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
- Original Implementation: https://github.com/haarnoja/sac
- Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/

.. note::
    In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon),
    which is equivalent to the inverse of the reward scale in the original SAC paper.
    The main reason is that it avoids having too high errors when updating the Q functions.
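
Concretely, SAC maximizes the entropy-regularized return, with the entropy coefficient :math:`\alpha` weighting the policy entropy term (standard formulation from the original paper):

.. math::

    J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi} \left[ r(s_t, a_t) + \alpha \, \mathcal{H}(\pi(\cdot \mid s_t)) \right]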


.. note::

    The default policy for SAC differs a bit from the other MlpPolicy: it uses ReLU instead of tanh activation,
    to match the original paper.


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ❌
Box           ✔️      ✔️
MultiDiscrete ❌      ❌
MultiBinary   ❌      ❌
============= ====== ===========


Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from torchy_baselines.sac.policies import MlpPolicy
    from torchy_baselines.common.vec_env import DummyVecEnv
    from torchy_baselines import SAC

    env = gym.make('Pendulum-v0')
    env = DummyVecEnv([lambda: env])

    model = SAC(MlpPolicy, env, verbose=1)
    model.learn(total_timesteps=50000, log_interval=10)
    model.save("sac_pendulum")

    del model  # remove to demonstrate saving and loading

    model = SAC.load("sac_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()

Parameters
----------

.. autoclass:: SAC
  :members:
  :inherited-members:

.. _sac_policies:

SAC Policies
-------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:
104  docs/modules/td3.rst  Normal file
@@ -0,0 +1,104 @@
.. _td3:

.. automodule:: torchy_baselines.td3


TD3
===

`Twin Delayed DDPG (TD3) <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_: Addressing Function Approximation Error in Actor-Critic Methods.

TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy updates and target policy smoothing.
We recommend reading the `OpenAI Spinning Up guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those.
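
As an illustration, the critic target used by TD3 can be sketched as follows (a minimal sketch of the standard TD3 target; ``actor_target``, ``critic_target_1``/``critic_target_2``, ``target_policy_noise``, ``target_noise_clip`` and ``gamma`` are illustrative names, not necessarily the ones used in the code):

.. code-block:: python

    import torch as th

    with th.no_grad():
        # target policy smoothing: add clipped noise to the target action
        next_actions = actor_target(next_obs)
        noise = (th.randn_like(next_actions) * target_policy_noise).clamp(-target_noise_clip, target_noise_clip)
        next_actions = (next_actions + noise).clamp(-1.0, 1.0)

        # clipped double Q-Learning: take the minimum of the two target critics
        target_q = th.min(critic_target_1(next_obs, next_actions),
                          critic_target_2(next_obs, next_actions))
        target_q = rewards + (1 - dones) * gamma * target_q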


.. warning::

    The TD3 model does not support ``torchy_baselines.common.policies`` because it uses double q-values
    estimation; as a result, it must use its own policy models (see :ref:`td3_policies`).


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy


Notes
-----

- Original paper: https://arxiv.org/pdf/1802.09477.pdf
- OpenAI Spinning Guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
- Original Implementation: https://github.com/sfujim/TD3

.. note::

    The default policy for TD3 differs a bit from the other MlpPolicy: it uses ReLU instead of tanh activation,
    to match the original paper.


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ❌
Box           ✔️      ✔️
MultiDiscrete ❌      ❌
MultiBinary   ❌      ❌
============= ====== ===========


Example
-------

.. code-block:: python

    import gym
    import numpy as np

    from torchy_baselines import TD3
    from torchy_baselines.td3.policies import MlpPolicy
    from torchy_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

    env = gym.make('Pendulum-v0')

    # The noise objects for TD3
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = TD3(MlpPolicy, 'Pendulum-v0', action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=50000, log_interval=10)
    model.save("td3_pendulum")
    env = model.get_env()

    del model  # remove to demonstrate saving and loading

    model = TD3.load("td3_pendulum")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()

Parameters
----------

.. autoclass:: TD3
  :members:
  :inherited-members:

.. _td3_policies:

TD3 Policies
-------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:
@@ -1 +1,2 @@
from torchy_baselines.cem_rl.cem_rl import CEMRL
from torchy_baselines.td3.policies import MlpPolicy
@@ -52,6 +52,7 @@ class DiagGaussianDistribution(Distribution):

    def proba_distribution_net(self, latent_dim, log_std_init=0.0):
        mean_actions = nn.Linear(latent_dim, self.action_dim)
        # TODO: allow action dependent std
        log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init)
        return mean_actions, log_std

@@ -73,6 +74,11 @@ class DiagGaussianDistribution(Distribution):
    def entropy(self):
        return self.distribution.entropy()

    def log_prob_from_params(self, mean_actions, log_std):
        action, _ = self.proba_distribution(mean_actions, log_std)
        log_prob = self.log_prob(action)
        return action, log_prob

    def log_prob(self, action):
        log_prob = self.distribution.log_prob(action)
        if len(log_prob.shape) > 1:
@@ -87,6 +93,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
        super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
        # Avoid NaN (prevents division by zero or log of zero)
        self.epsilon = epsilon
        self.gaussian_action = None

    def proba_distribution(self, mean_actions, log_std, deterministic=False):
        action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, deterministic)

@@ -114,6 +121,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):

        # Log likelihood for a gaussian distribution
        log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
        # Squash correction (from original implementation)
        # Squash correction (from original SAC implementation)
        log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
        return log_prob
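
For context, the squash correction above is the tanh change-of-variables term from the SAC derivation (:math:`u` is the pre-tanh Gaussian sample, :math:`\epsilon` the small constant added for numerical stability):

.. math::

    \log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i} \log\left(1 - \tanh(u_i)^2 + \epsilon\right)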
@@ -124,10 +124,14 @@ def register_policy(name, policy):
    :param policy: (subclass of BasePolicy) the policy
    """
    sub_class = None
    for cls in BasePolicy.__subclasses__():
        if issubclass(policy, cls):
            sub_class = cls
            break
    # For building the doc
    try:
        for cls in BasePolicy.__subclasses__():
            if issubclass(policy, cls):
                sub_class = cls
                break
    except AttributeError:
        sub_class = str(th.random.randint(100))
    if sub_class is None:
        raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy))
@@ -1 +1,2 @@
from torchy_baselines.ppo.ppo import PPO
from torchy_baselines.ppo.policies import MlpPolicy
@@ -5,6 +5,7 @@ from copy import deepcopy
import gym
import torch as th
import torch.nn.functional as F
# Check if tensorboard is available for pytorch
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
@@ -21,28 +22,63 @@ from torchy_baselines.ppo.policies import PPOPolicy

class PPO(BaseRLModel):
    """
    Implementation of Proximal Policy Optimization (PPO) (clip version)
    Proximal Policy Optimization algorithm (PPO) (clip version)

    Paper: https://arxiv.org/abs/1707.06347
    Code: https://github.com/openai/spinningup/
    and https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
    and stable_baselines
    Code: This implementation borrows code from OpenAI spinningup (https://github.com/openai/spinningup/),
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
    and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: (float or callable) The learning rate, it can be a function
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param batch_size: (int) Minibatch size
    :param n_epochs: (int) Number of epochs when optimizing the surrogate loss
    :param gamma: (float) Discount factor
    :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: (float or callable) Clipping parameter, it can be a function
    :param clip_range_vf: (float or callable) Clipping parameter for the value function, it can be a function.
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param ent_coef: (float) Entropy coefficient for the loss calculation
    :param vf_coef: (float) Value function coefficient for the loss calculation
    :param max_grad_norm: (float) The maximum value for the gradient clipping
    :param target_kl: (float) Limit the KL divergence between updates,
        because the clipping is not enough to prevent large updates,
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213).
        By default, there is no limit on the kl div.
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param create_eval_env: (bool) Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param seed: (int) Seed for the pseudo random generators
    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """
    def __init__(self, policy, env, policy_kwargs=None, verbose=0,
                 learning_rate=3e-4, seed=0, device='auto',
                 n_optim=5, batch_size=64, n_steps=256,
                 gamma=0.99, gae_lambda=0.95, clip_range=0.2,
                 ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
                 target_kl=None, clip_range_vf=None, create_eval_env=False,
                 tensorboard_log=None,
    def __init__(self, policy, env, learning_rate=3e-4,
                 n_steps=2048, batch_size=64, n_epochs=10,
                 gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None,
                 ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5,
                 target_kl=None, tensorboard_log=None, create_eval_env=False,
                 policy_kwargs=None, verbose=0, seed=0, device='auto',
                 _init_setup_model=True):

        super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs,
                                  verbose, device, create_eval_env=create_eval_env, support_multi_env=True)
        super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs,
                                  verbose=verbose, device=device,
                                  create_eval_env=create_eval_env, support_multi_env=True)

        self.learning_rate = learning_rate
        self.seed = seed
        self.batch_size = batch_size
        self.n_optim = n_optim
        self.n_epochs = n_epochs
        self.n_steps = n_steps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
@@ -177,7 +213,7 @@ class PPO(BaseRLModel):
        obs = self.env.reset()
        eval_env = self._get_eval_env(eval_env)

        if self.tensorboard_log is not None:
        if self.tensorboard_log is not None and SummaryWriter is not None:
            self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))

        while self.num_timesteps < total_timesteps:

@@ -193,7 +229,7 @@ class PPO(BaseRLModel):
            self.num_timesteps += self.n_steps * self.n_envs
            timesteps_since_eval += self.n_steps * self.n_envs

            self.train(self.n_optim, batch_size=self.batch_size)
            self.train(self.n_epochs, batch_size=self.batch_size)

            # Evaluate agent
            if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
@@ -1 +1,2 @@
from torchy_baselines.sac.sac import SAC
from torchy_baselines.sac.policies import MlpPolicy
@@ -69,7 +69,7 @@ class Critic(BaseNetwork):

class SACPolicy(BasePolicy):
    def __init__(self, observation_space, action_space,
                 learning_rate=1e-3, net_arch=None, device='cpu',
                 learning_rate=3e-4, net_arch=None, device='cpu',
                 activation_fn=nn.ReLU):
        super(SACPolicy, self).__init__(observation_space, action_space, device)
        self.obs_dim = self.observation_space.shape[0]

@@ -96,9 +96,6 @@ class SACPolicy(BasePolicy):
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate)

    def actor_forward(self, state, deterministic=False):
        pass

    def make_actor(self):
        return Actor(**self.net_args).to(self.device)
@@ -48,6 +48,8 @@ class SAC(BaseRLModel):
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param seed: (int) Seed for the pseudo random generators
    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """
    def __init__(self, policy, env, learning_rate=3e-4, buffer_size=int(1e6),
@@ -1 +1,2 @@
from torchy_baselines.td3.td3 import TD3
from torchy_baselines.td3.policies import MlpPolicy
@@ -42,6 +42,8 @@ class TD3(BaseRLModel):
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
    :param seed: (int) Seed for the pseudo random generators
    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """
    def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3,