diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..47f98cd
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS = -W # make warnings fatal
+SPHINXBUILD = sphinx-build
+SPHINXPROJ = StableBaselines
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000..ffb3753
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,25 @@
+## Stable Baselines Documentation
+
+This folder contains documentation for the RL baselines.
+
+
+### Build the Documentation
+
+#### Install Sphinx and Theme
+
+```
+pip install sphinx sphinx-autobuild sphinx-rtd-theme
+```
+
+#### Building the Docs
+
+In the `docs/` folder:
+```
+make html
+```
+
+If you want to rebuild the documentation each time a file is changed:
+
+```
+sphinx-autobuild . _build/html
+```
diff --git a/docs/_static/css/baselines_theme.css b/docs/_static/css/baselines_theme.css
new file mode 100644
index 0000000..89455aa
--- /dev/null
+++ b/docs/_static/css/baselines_theme.css
@@ -0,0 +1,52 @@
+/* Main colors adapted from pytorch doc */
+:root{
+    --main-bg-color: #343A40;
+    --link-color: #FD7E14;
+}
+
+/* Header fonts */
+h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
+    font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
+}
+
+
+/* Docs background */
+.wy-side-nav-search{
+    background-color: var(--main-bg-color);
+}
+
+/* Mobile version */
+.wy-nav-top{
+    background-color: var(--main-bg-color);
+}
+
+/* Change link colors (except for the menu) */
+a {
+    color: var(--link-color);
+}
+
+a:hover {
+    color: #4F778F;
+}
+
+.wy-menu a {
+    color: #b3b3b3;
+}
+
+.wy-menu a:hover {
+    color: #b3b3b3;
+}
+
+a.icon.icon-home {
+    color: #b3b3b3;
+}
+
+.version{
+    color: var(--link-color) !important;
+}
+
+
+/* Make code blocks have a background */
+.codeblock,pre.literal-block,.rst-content .literal-block,.rst-content pre.literal-block,div[class^='highlight'] {
+    background: #f8f8f8;
+}
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..606a195
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,203 @@
+# -*- coding: utf-8 -*-
+#
+# Configuration file for the Sphinx documentation builder.
+#
+# This file does only contain a selection of the most common options. For a
+# full list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+from unittest.mock import MagicMock
+
+# source code directory, relative to this file, for sphinx-autobuild
+sys.path.insert(0, os.path.abspath('..'))
+
+
+class Mock(MagicMock):
+    __subclasses__ = []
+
+    @classmethod
+    def __getattr__(cls, name):
+        return MagicMock()
+
+
+# Mock modules that require C modules
+# Note: because of that, we cannot test the examples on CI
+# 'torch', 'torch.nn', 'torch.nn.functional',
+MOCK_MODULES = ['joblib', 'scipy', 'scipy.signal',
+                'pandas', 'mpi4py', 'mujoco_py', 'cv2',
+                'tensorflow', 'torch', 'torch.nn', 'torch.nn.functional',
+                'torch.distributions',
+                'tensorflow.contrib', 'tensorflow.contrib.layers',
+                'tensorflow.python', 'tensorflow.python.client', 'tensorflow.python.ops',
+                'tqdm', 'cloudpickle', 'matplotlib', 'matplotlib.pyplot',
+                'seaborn', 'gym', 'gym.spaces', 'gym.core',
+                'tensorflow.core', 'tensorflow.core.util', 'tensorflow.python.util',
+                'gym.wrappers', 'gym.wrappers.monitoring', 'zmq']
+sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
+
+
+import torchy_baselines
+
+
+# -- Project information -----------------------------------------------------
+
+project = 'Torchy Baselines'
+copyright = '2019, Torchy Baselines'
+author = 'Torchy Baselines Contributors'
+
+# The short X.Y version
+version = 'master (' + torchy_baselines.__version__ + ')'
+# The full version, including alpha/beta/rc tags
+release = torchy_baselines.__version__
+
+
+# -- General configuration ---------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'sphinx.ext.mathjax',
+    'sphinx.ext.ifconfig',
+    'sphinx.ext.viewcode',
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffixes as a list of strings:
+#
+# source_suffix = ['.rst', '.md']
+source_suffix = '.rst'
+
+# The master toctree document.
+master_doc = 'index'
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = 'sphinx'
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+
+# Fix for read the docs
+on_rtd = os.environ.get('READTHEDOCS') == 'True'
+if on_rtd:
+    html_theme = 'default'
+else:
+    html_theme = 'sphinx_rtd_theme'
+
+html_logo = '_static/img/logo.png'
+
+
+def setup(app):
+    app.add_stylesheet("css/baselines_theme.css")
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+#
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+
+# Custom sidebar templates, must be a dictionary that maps document names
+# to template names.
+#
+# The default sidebars (for documents that don't match any pattern) are
+# defined by theme itself. Builtin themes are using these templates by
+# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
+# 'searchbox.html']``.
+#
+# html_sidebars = {}
+
+
+# -- Options for HTMLHelp output ---------------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = 'TorchyBaselinesdoc'
+
+
+# -- Options for LaTeX output ------------------------------------------------
+
+latex_elements = {
+    # The paper size ('letterpaper' or 'a4paper').
+    #
+    # 'papersize': 'letterpaper',
+
+    # The font size ('10pt', '11pt' or '12pt').
+    #
+    # 'pointsize': '10pt',
+
+    # Additional stuff for the LaTeX preamble.
+    #
+    # 'preamble': '',
+
+    # Latex figure (float) alignment
+    #
+    # 'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title,
+#  author, documentclass [howto, manual, or own class]).
+latex_documents = [
+    (master_doc, 'TorchyBaselines.tex', 'Torchy Baselines Documentation',
+     'Torchy Baselines Contributors', 'manual'),
+]
+
+
+# -- Options for manual page output ------------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [
+    (master_doc, 'torchybaselines', 'Torchy Baselines Documentation',
+     [author], 1)
+]
+
+
+# -- Options for Texinfo output ----------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+#  dir menu entry, description, category)
+texinfo_documents = [
+    (master_doc, 'TorchyBaselines', 'Torchy Baselines Documentation',
+     author, 'TorchyBaselines', 'One line description of project.',
+     'Miscellaneous'),
+]
+
+
+# -- Extension configuration -------------------------------------------------
diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst
new file mode 100644
index 0000000..94e6601
--- /dev/null
+++ b/docs/guide/quickstart.rst
@@ -0,0 +1,40 @@
+.. _quickstart:
+
+===============
+Getting Started
+===============
+
+Most of the library tries to follow a scikit-learn-like syntax for the reinforcement learning algorithms.
+
+Here is a quick example of how to train and run SAC on a Pendulum environment:
+
+.. code-block:: python
+
+    import gym
+
+    from torchy_baselines.sac.policies import MlpPolicy
+    from torchy_baselines.common.vec_env import DummyVecEnv
+    from torchy_baselines import SAC
+
+    # The algorithms require a vectorized environment to run
+    env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
+
+    model = SAC(MlpPolicy, env, verbose=1)
+    model.learn(total_timesteps=10000)
+
+    obs = env.reset()
+    for i in range(1000):
+        action, _states = model.predict(obs)
+        obs, rewards, dones, info = env.step(action)
+        env.render()
+
+
+Or just train a model with a one-liner if
+`the environment is registered in Gym `_ and if
+the policy is registered:
+
+.. code-block:: python
+
+    from torchy_baselines import SAC
+
+    model = SAC('MlpPolicy', 'Pendulum-v0').learn(10000)
diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst
new file mode 100644
index 0000000..0a989ea
--- /dev/null
+++ b/docs/guide/vec_envs.rst
@@ -0,0 +1,71 @@
+.. _vec_env:
+
+.. automodule:: torchy_baselines.common.vec_env
+
+Vectorized Environments
+=======================
+
+Vectorized Environments are a method for stacking multiple independent environments into a single environment.
+Instead of training an RL agent on one environment per step, it allows us to train it on `n` environments per step.
+Because of this, `actions` passed to the environment are now a vector (of dimension `n`).
+It is the same for `observations`, `rewards` and end of episode signals (`dones`).
+In the case of non-array observation spaces such as `Dict` or `Tuple`, where different sub-spaces
+may have different shapes, the sub-observations are vectors (of dimension `n`).
+
+============= ======= ============ ======== ========= ================
+Name          ``Box`` ``Discrete`` ``Dict`` ``Tuple`` Multi Processing
+============= ======= ============ ======== ========= ================
+DummyVecEnv   ✔️       ✔️            ✔️        ✔️         ❌️
+SubprocVecEnv ✔️       ✔️            ✔️        ✔️         ✔️
+============= ======= ============ ======== ========= ================
+
+.. note::
+
+    Vectorized environments are required when using wrappers for frame-stacking or normalization.
+
+.. note::
+
+    When using vectorized environments, the environments are automatically reset at the end of each episode.
+    Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
+    You can access the "real" final observation of the terminated episode (the one that accompanied the ``done`` event provided by the underlying environment) using the ``terminal_observation`` keys in the info dicts returned by the vecenv.
+
+.. warning::
+
+    When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` block when using the ``forkserver`` or ``spawn`` start method (default on Windows).
+    On Linux, the default start method is ``fork``, which is not thread-safe and can create deadlocks.
+
+    For more information, see Python's `multiprocessing guidelines `_.
+
+VecEnv
+------
+
+.. autoclass:: VecEnv
+    :members:
+
+DummyVecEnv
+-----------
+
+.. autoclass:: DummyVecEnv
+    :members:
+
+SubprocVecEnv
+-------------
+
+.. autoclass:: SubprocVecEnv
+    :members:
+
+Wrappers
+--------
+
+VecFrameStack
+~~~~~~~~~~~~~
+
+.. autoclass:: VecFrameStack
+    :members:
+
+
+VecNormalize
+~~~~~~~~~~~~
+
+.. autoclass:: VecNormalize
+    :members:
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..77dab25
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,63 @@
+.. Stable Baselines documentation master file, created by
+   sphinx-quickstart on Thu Sep 26 11:06:54 2019.
+   You can adapt this file completely to your liking, but it should at least
+   contain the root `toctree` directive.
+
+Welcome to Torchy Baselines docs! - PyTorch RL Baselines
+========================================================
+
+`Torchy Baselines `_ is the PyTorch version of `Stable Baselines `_,
+a set of improved implementations of reinforcement learning algorithms.
+
+RL Baselines Zoo (collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo
+
+The RL Baselines Zoo also offers a simple interface to train and evaluate agents, and to do hyperparameter tuning.
+
+
+
+.. toctree::
+   :maxdepth: 2
+   :caption: User Guide
+
+   guide/quickstart
+   guide/vec_envs
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: RL Algorithms
+
+   modules/base
+   modules/ppo
+   modules/sac
+   modules/td3
+
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Misc
+
+   misc/changelog
+
+
+Citing Torchy Baselines
+-----------------------
+To cite this project in publications:
+
+.. code-block:: bibtex
+
+    @misc{torchy-baselines,
+      author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi},
+      title = {Torchy Baselines},
+      year = {2019},
+      publisher = {GitHub},
+      journal = {GitHub repository},
+      howpublished = {\url{https://github.com/hill-a/stable-baselines}},
+    }
+
+Indices and tables
+-------------------
+
+* :ref:`genindex`
+* :ref:`search`
+* :ref:`modindex`
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..22b5fff
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,36 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+set SPHINXPROJ=StableBaselines
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+
+:end
+popd
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
new file mode 100644
index 0000000..4af5310
--- /dev/null
+++ b/docs/misc/changelog.rst
@@ -0,0 +1,46 @@
+.. _changelog:
+
+Changelog
+==========
+
+
+Pre-Release 0.0.3a0 (WIP)
+-------------------------
+**Initial Release**
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+
+New Features:
+^^^^^^^^^^^^^
+- Initial release of CEM-RL, PPO, SAC and TD3
+
+Bug Fixes:
+^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+
+
+Maintainers
+-----------
+
+Torchy-Baselines is currently maintained by `Antonin Raffin`_ (aka `@araffin`_).
+
+.. _Antonin Raffin: https://araffin.github.io/
+.. _@araffin: https://github.com/araffin
+
+
+
+Contributors:
+-------------
+In random order...
+
+Thanks to @hill-a @enerijunior @AdamGleave @Miffyli
diff --git a/docs/modules/base.rst b/docs/modules/base.rst
new file mode 100644
index 0000000..d32268d
--- /dev/null
+++ b/docs/modules/base.rst
@@ -0,0 +1,12 @@
+.. _base_algo:
+
+.. automodule:: torchy_baselines.common.base_class
+
+
+Base RL Class
+=============
+
+Common interface for all the RL algorithms
+
+.. autoclass:: BaseRLModel
+    :members:
diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst
new file mode 100644
index 0000000..5531d3a
--- /dev/null
+++ b/docs/modules/ppo.rst
@@ -0,0 +1,83 @@
+.. _ppo2:
+
+.. automodule:: torchy_baselines.ppo
+
+PPO
+===
+
+The `Proximal Policy Optimization `_ algorithm combines ideas from A2C (having multiple workers)
+and TRPO (it uses a trust region to improve the actor).
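+For reference, the objective that PPO maximizes (as written in the original paper; the notation below is the paper's, not this implementation's) is the clipped surrogate:
+
+.. math::
+
+    L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta) \hat{A}_t,\; \text{clip}\left(r_t(\theta), 1 - \epsilon, 1 + \epsilon\right) \hat{A}_t\right)\right]
+
+where :math:`r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{old}}(a_t | s_t)}` is the probability ratio between the new and the old policy, and :math:`\hat{A}_t` is an estimate of the advantage at time :math:`t`.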
+
+The main idea is that after an update, the new policy should not be too far from the old policy.
+For that, PPO uses clipping (the :math:`\text{clip}` term in the objective above) to avoid too large an update.
+
+
+.. note::
+
+    PPO contains several modifications from the original algorithm not documented
+    by OpenAI: advantages are normalized and the value function can also be clipped.
+
+
+Notes
+-----
+
+- Original paper: https://arxiv.org/abs/1707.06347
+- Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
+- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
+
+
+Can I use?
+----------
+
+- Recurrent policies: ❌
+- Multi processing: ✔️
+- Gym spaces:
+
+
+============= ====== ===========
+Space         Action Observation
+============= ====== ===========
+Discrete      ❌      ❌
+Box           ✔️      ✔️
+MultiDiscrete ❌      ❌
+MultiBinary   ❌      ❌
+============= ====== ===========
+
+Example
+-------
+
+Train a PPO agent on `Pendulum-v0` using 4 processes.
+
+.. code-block:: python
+
+    import gym
+
+    from torchy_baselines.ppo.policies import MlpPolicy
+    from torchy_baselines.common.vec_env import SubprocVecEnv
+    from torchy_baselines import PPO
+
+    # multiprocess environment
+    n_cpu = 4
+    env = SubprocVecEnv([lambda: gym.make('Pendulum-v0') for i in range(n_cpu)])
+
+    model = PPO(MlpPolicy, env, verbose=1)
+    model.learn(total_timesteps=25000)
+    model.save("ppo_pendulum")
+
+    del model  # remove to demonstrate saving and loading
+
+    model = PPO.load("ppo_pendulum")
+
+    # Enjoy trained agent
+    obs = env.reset()
+    while True:
+        action, _states = model.predict(obs)
+        obs, rewards, dones, info = env.step(action)
+        env.render()
+
+Parameters
+----------
+
+.. autoclass:: PPO
+    :members:
+    :inherited-members:
diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst
new file mode 100644
index 0000000..7bd949d
--- /dev/null
+++ b/docs/modules/sac.rst
@@ -0,0 +1,110 @@
+.. _sac:
+
+.. automodule:: torchy_baselines.sac
+
+
+SAC
+===
+
+`Soft Actor Critic (SAC) `_: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
+
+SAC is the successor of `Soft Q-Learning SQL `_ and incorporates the double Q-learning trick from TD3.
+A key feature of SAC, and a major difference from common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
+
+
+.. warning::
+
+    The SAC model does not support ``torchy_baselines.common.policies`` because it uses double Q-value
+    and value estimation; as a result, it must use its own policy models (see :ref:`sac_policies`).
+
+
+.. rubric:: Available Policies
+
+.. autosummary::
+    :nosignatures:
+
+    MlpPolicy
+
+
+Notes
+-----
+
+- Original paper: https://arxiv.org/abs/1801.01290
+- OpenAI Spinning Guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
+- Original Implementation: https://github.com/haarnoja/sac
+- Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/
+
+.. note::
+
+    In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon),
+    which is equivalent to the inverse of the reward scale in the original SAC paper.
+    The main reason is that this avoids too-large errors when updating the Q functions.
+
+
+.. note::
+
+    The default policy for SAC differs a bit from the other MlpPolicy: it uses ReLU instead of tanh activation,
+    to match the original paper.
+
+
+Can I use?
+----------
+
+- Recurrent policies: ❌
+- Multi processing: ❌
+- Gym spaces:
+
+
+============= ====== ===========
+Space         Action Observation
+============= ====== ===========
+Discrete      ❌      ❌
+Box           ✔️      ✔️
+MultiDiscrete ❌      ❌
+MultiBinary   ❌      ❌
+============= ====== ===========
+
+
+Example
+-------
+
+.. code-block:: python
+
+    import gym
+
+    from torchy_baselines.sac.policies import MlpPolicy
+    from torchy_baselines.common.vec_env import DummyVecEnv
+    from torchy_baselines import SAC
+
+    env = gym.make('Pendulum-v0')
+    env = DummyVecEnv([lambda: env])
+
+    model = SAC(MlpPolicy, env, verbose=1)
+    model.learn(total_timesteps=50000, log_interval=10)
+    model.save("sac_pendulum")
+
+    del model  # remove to demonstrate saving and loading
+
+    model = SAC.load("sac_pendulum")
+
+    obs = env.reset()
+    while True:
+        action, _states = model.predict(obs)
+        obs, rewards, dones, info = env.step(action)
+        env.render()
+
+Parameters
+----------
+
+.. autoclass:: SAC
+    :members:
+    :inherited-members:
+
+.. _sac_policies:
+
+SAC Policies
+-------------
+
+.. autoclass:: MlpPolicy
+    :members:
+    :inherited-members:
diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst
new file mode 100644
index 0000000..b476350
--- /dev/null
+++ b/docs/modules/td3.rst
@@ -0,0 +1,104 @@
+.. _td3:
+
+.. automodule:: torchy_baselines.td3
+
+
+TD3
+===
+
+`Twin Delayed DDPG (TD3) `_: Addressing Function Approximation Error in Actor-Critic Methods.
+
+TD3 is a direct successor of DDPG and improves it using three major tricks: clipped double Q-Learning, delayed policy updates and target policy smoothing.
+We recommend reading the `OpenAI Spinning guide on TD3 `_ to learn more about those.
+
+
+.. warning::
+
+    The TD3 model does not support ``torchy_baselines.common.policies`` because it uses double Q-value
+    estimation; as a result, it must use its own policy models (see :ref:`td3_policies`).
+
+
+.. rubric:: Available Policies
+
+.. autosummary::
+    :nosignatures:
+
+    MlpPolicy
+
+
+Notes
+-----
+
+- Original paper: https://arxiv.org/pdf/1802.09477.pdf
+- OpenAI Spinning Guide for TD3: https://spinningup.openai.com/en/latest/algorithms/td3.html
+- Original Implementation: https://github.com/sfujim/TD3
+
+.. note::
+
+    The default policy for TD3 differs a bit from the other MlpPolicy: it uses ReLU instead of tanh activation,
+    to match the original paper.
+
+
+Can I use?
+----------
+
+- Recurrent policies: ❌
+- Multi processing: ❌
+- Gym spaces:
+
+
+============= ====== ===========
+Space         Action Observation
+============= ====== ===========
+Discrete      ❌      ❌
+Box           ✔️      ✔️
+MultiDiscrete ❌      ❌
+MultiBinary   ❌      ❌
+============= ====== ===========
+
+
+Example
+-------
+
+.. code-block:: python
+
+    import gym
+    import numpy as np
+
+    from torchy_baselines import TD3
+    from torchy_baselines.td3.policies import MlpPolicy
+    from torchy_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
+
+    env = gym.make('Pendulum-v0')
+
+    # The noise objects for TD3
+    n_actions = env.action_space.shape[-1]
+    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
+
+    model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
+    model.learn(total_timesteps=50000, log_interval=10)
+    model.save("td3_pendulum")
+
+    del model  # remove to demonstrate saving and loading
+
+    model = TD3.load("td3_pendulum")
+
+    obs = env.reset()
+    while True:
+        action, _states = model.predict(obs)
+        obs, rewards, dones, info = env.step(action)
+        env.render()
+
+Parameters
+----------
+
+.. autoclass:: TD3
+    :members:
+    :inherited-members:
+
+.. _td3_policies:
+
+TD3 Policies
+-------------
+
+.. autoclass:: MlpPolicy
+    :members:
+    :inherited-members:
diff --git a/torchy_baselines/cem_rl/__init__.py b/torchy_baselines/cem_rl/__init__.py
index b93cd30..52c3c9a 100644
--- a/torchy_baselines/cem_rl/__init__.py
+++ b/torchy_baselines/cem_rl/__init__.py
@@ -1 +1,2 @@
 from torchy_baselines.cem_rl.cem_rl import CEMRL
+from torchy_baselines.td3.policies import MlpPolicy
diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py
index ec19087..bca56f5 100644
--- a/torchy_baselines/common/distributions.py
+++ b/torchy_baselines/common/distributions.py
@@ -52,6 +52,7 @@ class DiagGaussianDistribution(Distribution):
 
     def proba_distribution_net(self, latent_dim, log_std_init=0.0):
         mean_actions = nn.Linear(latent_dim, self.action_dim)
+        # TODO: allow action dependent std
        log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init)
         return mean_actions, log_std
 
@@ -73,6 +74,11 @@ class DiagGaussianDistribution(Distribution):
     def entropy(self):
         return self.distribution.entropy()
 
+    def log_prob_from_params(self, mean_actions, log_std):
+        action, _ = self.proba_distribution(mean_actions, log_std)
+        log_prob = self.log_prob(action)
+        return action, log_prob
+
     def log_prob(self, action):
         log_prob = self.distribution.log_prob(action)
         if len(log_prob.shape) > 1:
@@ -87,6 +93,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
         super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
         # Avoid NaN (prevents division by zero or log of zero)
         self.epsilon = epsilon
+        self.gaussian_action = None
 
     def proba_distribution(self, mean_actions, log_std, deterministic=False):
         action, _ = super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std, deterministic)
@@ -114,6 +121,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
         # Log likelihood for a gaussian distribution
         log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
 
-        # Squash correction (from original implementation)
+        # Squash correction (from original SAC implementation)
         log_prob -= th.sum(th.log(1 - action ** 2 + self.epsilon), dim=1)
         return log_prob
diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py
index 38922ba..55e356c 100644
--- a/torchy_baselines/common/policies.py
+++ b/torchy_baselines/common/policies.py
@@ -124,10 +124,14 @@ def register_policy(name, policy):
     :param policy: (subclass of BasePolicy) the policy
     """
     sub_class = None
-    for cls in BasePolicy.__subclasses__():
-        if issubclass(policy, cls):
-            sub_class = cls
-            break
+    # For building the doc
+    try:
+        for cls in BasePolicy.__subclasses__():
+            if issubclass(policy, cls):
+                sub_class = cls
+                break
+    except AttributeError:
+        sub_class = str(th.random.randint(100))
 
     if sub_class is None:
         raise ValueError("Error: the policy {} is not of any known subclasses of BasePolicy!".format(policy))
diff --git a/torchy_baselines/ppo/__init__.py b/torchy_baselines/ppo/__init__.py
index 2ce3051..72a5560 100644
--- a/torchy_baselines/ppo/__init__.py
+++ b/torchy_baselines/ppo/__init__.py
@@ -1 +1,2 @@
 from torchy_baselines.ppo.ppo import PPO
+from torchy_baselines.ppo.policies import MlpPolicy
diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py
index 2dc738c..1efd412 100644
--- a/torchy_baselines/ppo/ppo.py
+++ b/torchy_baselines/ppo/ppo.py
@@ -5,6 +5,7 @@ from copy import deepcopy
 import gym
 import torch as th
 import torch.nn.functional as F
+# Check if tensorboard is available for pytorch
 try:
     from torch.utils.tensorboard import SummaryWriter
 except ImportError:
@@ -21,28 +22,63 @@ from torchy_baselines.ppo.policies import PPOPolicy
 
 class PPO(BaseRLModel):
     """
-    Implementation of Proximal Policy Optimization (PPO) (clip version)
+    Proximal Policy Optimization algorithm (PPO) (clip version)
+
     Paper: https://arxiv.org/abs/1707.06347
-    Code: https://github.com/openai/spinningup/
-    and https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
-    and stable_baselines
+    Code: This implementation borrows code from OpenAI spinningup (https://github.com/openai/spinningup/),
+    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
+    and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
+
+    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
+
+    :param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
+    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
+    :param learning_rate: (float or callable) The learning rate, it can be a function
+    :param n_steps: (int) The number of steps to run for each environment per update
+        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
+    :param batch_size: (int) Minibatch size
+    :param n_epochs: (int) Number of epochs when optimizing the surrogate loss
+    :param gamma: (float) Discount factor
+    :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+    :param clip_range: (float or callable) Clipping parameter, it can be a function
+    :param clip_range_vf: (float or callable) Clipping parameter for the value function, it can be a function.
+        This is a parameter specific to the OpenAI implementation. If None is passed (default),
+        no clipping will be done on the value function.
+        IMPORTANT: this clipping depends on the reward scaling.
+    :param ent_coef: (float) Entropy coefficient for the loss calculation
+    :param vf_coef: (float) Value function coefficient for the loss calculation
+    :param max_grad_norm: (float) The maximum value for the gradient clipping
+    :param target_kl: (float) Limit the KL divergence between updates,
+        because the clipping is not enough to prevent large updates
+        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
+        By default, there is no limit on the kl div.
+    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
+    :param create_eval_env: (bool) Whether to create a second environment that will be
+        used for evaluating the agent periodically. (Only available when passing string for the environment)
+    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
+    :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
+    :param seed: (int) Seed for the pseudo random generators
+    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
+        Setting it to auto, the code will be run on the GPU if possible.
+    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
     """
-    def __init__(self, policy, env, policy_kwargs=None, verbose=0,
-                 learning_rate=3e-4, seed=0, device='auto',
-                 n_optim=5, batch_size=64, n_steps=256,
-                 gamma=0.99, gae_lambda=0.95, clip_range=0.2,
-                 ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5,
-                 target_kl=None, clip_range_vf=None, create_eval_env=False,
-                 tensorboard_log=None,
+
+    def __init__(self, policy, env, learning_rate=3e-4,
+                 n_steps=2048, batch_size=64, n_epochs=10,
+                 gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None,
+                 ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5,
+                 target_kl=None, tensorboard_log=None, create_eval_env=False,
+                 policy_kwargs=None, verbose=0, seed=0, device='auto',
                  _init_setup_model=True):
 
-        super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs,
-                                  verbose, device, create_eval_env=create_eval_env, support_multi_env=True)
+        super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs,
+                                  verbose=verbose, device=device,
+                                  create_eval_env=create_eval_env, support_multi_env=True)
 
         self.learning_rate = learning_rate
         self.seed = seed
         self.batch_size = batch_size
-        self.n_optim = n_optim
+        self.n_epochs = n_epochs
         self.n_steps = n_steps
         self.gamma = gamma
         self.gae_lambda = gae_lambda
@@ -177,7 +213,7 @@ class PPO(BaseRLModel):
         obs = self.env.reset()
         eval_env = self._get_eval_env(eval_env)
 
-        if self.tensorboard_log is not None:
+        if self.tensorboard_log is not None and SummaryWriter is not None:
             self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
 
         while self.num_timesteps < total_timesteps:
@@ -193,7 +229,7 @@ class PPO(BaseRLModel):
             self.num_timesteps += self.n_steps * self.n_envs
             timesteps_since_eval += self.n_steps * self.n_envs
 
-            self.train(self.n_optim, batch_size=self.batch_size)
+            self.train(self.n_epochs, batch_size=self.batch_size)
 
             # Evaluate agent
             if 0 < eval_freq <= timesteps_since_eval and eval_env is not None:
diff --git a/torchy_baselines/sac/__init__.py b/torchy_baselines/sac/__init__.py
index 6f70061..1132a37 100644
--- a/torchy_baselines/sac/__init__.py
+++ b/torchy_baselines/sac/__init__.py
@@ -1 +1,2 @@
 from torchy_baselines.sac.sac import SAC
+from torchy_baselines.sac.policies import MlpPolicy
diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py
index 34a1534..4e1b544 100644
--- a/torchy_baselines/sac/policies.py
+++ b/torchy_baselines/sac/policies.py
@@ -69,7 +69,7 @@ class Critic(BaseNetwork):
 
 class SACPolicy(BasePolicy):
     def __init__(self, observation_space, action_space,
-                 learning_rate=1e-3, net_arch=None, device='cpu',
+                 learning_rate=3e-4, net_arch=None, device='cpu',
                  activation_fn=nn.ReLU):
         super(SACPolicy, self).__init__(observation_space, action_space, device)
         self.obs_dim = self.observation_space.shape[0]
@@ -96,9 +96,6 @@
         self.critic_target.load_state_dict(self.critic.state_dict())
         self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate)
 
-    def actor_forward(self, state, deterministic=False):
-        pass
-
     def make_actor(self):
         return Actor(**self.net_args).to(self.device)
diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py
index 07fcc52..56c741c 100644
--- a/torchy_baselines/sac/sac.py
+++ b/torchy_baselines/sac/sac.py
@@ -48,6 +48,8 @@ class SAC(BaseRLModel):
     :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
     :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
     :param seed: (int) Seed for the pseudo random generators
+    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
+        Setting it to auto, the code will be run on the GPU if possible.
     :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
     """
     def __init__(self, policy, env, learning_rate=3e-4, buffer_size=int(1e6),
diff --git a/torchy_baselines/td3/__init__.py b/torchy_baselines/td3/__init__.py
index 51225e6..148be49 100644
--- a/torchy_baselines/td3/__init__.py
+++ b/torchy_baselines/td3/__init__.py
@@ -1 +1,2 @@
 from torchy_baselines.td3.td3 import TD3
+from torchy_baselines.td3.policies import MlpPolicy
diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py
index 4a8b69c..8fc00ae 100644
--- a/torchy_baselines/td3/td3.py
+++ b/torchy_baselines/td3/td3.py
@@ -42,6 +42,8 @@ class TD3(BaseRLModel):
     :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
     :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug
     :param seed: (int) Seed for the pseudo random generators
+    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
+        Setting it to auto, the code will be run on the GPU if possible.
     :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
     """
     def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3,