From fc6c5d3daa6a84c1bf4cf67edd676e164a7b88b2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 11 Oct 2020 23:22:12 +0200 Subject: [PATCH] Migration Guide (#123) * Start migration guide * Update guide * Add comment on RMSpropTFLike plus PPO/A2C migrations * Add note about set/get-parameters * Update migration guide * Update changelog and readme * Update doc + clean changelog * Address comments Co-authored-by: Anssi "Miffyli" Kanervisto --- README.md | 4 +- docs/guide/install.rst | 2 +- docs/guide/migration.rst | 190 ++++++++++++++++++++++++++++++++++++++- docs/misc/changelog.rst | 8 +- 4 files changed, 195 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 95e7c15..94c005d 100644 --- a/README.md +++ b/README.md @@ -50,9 +50,9 @@ Planned features: - [ ] TRPO -## Migration guide +## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3) -**TODO: migration guide from Stable-Baselines in the documentation** +A migration guide from SB2 to SB3 can be found in the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html). ## Documentation diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 27c431b..e5ec926 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -29,7 +29,7 @@ To install Stable Baselines3 with pip, execute: pip install stable-baselines3[extra] -This includes an optional dependencies like Tensorboard, OpenCV or ```atari-py``` to train on atari games. If you do not need those, you can use: +This includes an optional dependencies like Tensorboard, OpenCV or ``atari-py`` to train on atari games. If you do not need those, you can use: .. code-block:: bash diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst index 263c939..82fbc69 100644 --- a/docs/guide/migration.rst +++ b/docs/guide/migration.rst @@ -5,8 +5,194 @@ Migrating from Stable-Baselines ================================ -This is a guide to migrate from Stable-Baselines to Stable-Baselines3. +This is a guide to migrate from Stable-Baselines (SB2) to Stable-Baselines3 (SB3). It also references the main changes. -**TODO** + +Overview +======== + +Overall Stable-Baselines3 (SB3) keeps the high-level API of Stable-Baselines (SB2). +Most of the changes are to ensure more consistency and are internal ones. +Because of the backend change, from Tensorflow to PyTorch, the internal code is much much readable and easy to debug +at the cost of some speed (dynamic graph vs static graph., see `Issue #90 `_) +However, the algorithms were extensively benchmarked on Atari games and continuous control PyBullet envs +(see `Issue #48 `_ and `Issue #49 `_) +so you should not expect performance drop when switching from SB2 to SB3. + + +How to migrate? +=============== + +In most cases, replacing ``from stable_baselines`` by ``from stable_baselines3`` will be sufficient. +Some files were moved to the common folder (cf below) and could result to import errors. +Some algorithms were removed because of their complexity to improve the maintainability of the project. +We recommend reading this guide carefully to understand all the changes that were made. +You can also take a look at the `rl-zoo3 `_ and compare the imports +to the `rl-zoo `_ of SB2 to have a concrete example of successful migration. + + +Breaking Changes +================ + + +- SB3 requires python 3.6+ (instead of python 3.5+ for SB2) +- Dropped MPI support +- Dropped layer normalized policies (e.g. ``LnMlpPolicy``) +- Dropped parameter noise for DDPG and DQN +- PPO is now closer to the original implementation (no clipping of the value function by default), cf PPO section below +- Orthogonal initialization is only used by A2C/PPO +- The features extractor (CNN extractor) is shared between policy and q-networks for DDPG/SAC/TD3 and only the policy loss used to update it (much faster) +- Tensorboard legacy logging was dropped in favor of having one logger for the terminal and Tensorboard (cf :ref:`Tensorboard integration `) +- We dropped ACKTR/ACER support because of their complexity compared to simpler alternatives (PPO, SAC, TD3) performing as good. +- We dropped GAIL support as we are focusing on model-free RL only, you can however take a look at the `Imitation Learning Baseline Implementations `_ + which are based on SB3. +- ``action_probability`` is currently not implemented in the base class + +You can take a look at the `issue about SB3 implementation design `_ for more details. + + +Moved Files +----------- + +- ``bench/monitor.py`` -> ``common/monitor.py`` +- ``logger.py`` -> ``common/logger.py`` +- ``results_plotter.py`` -> ``common/results_plotter.py`` + +Utility functions are no longer exported from ``common`` module, you should import them with their absolute path, e.g.: + +.. code-block:: python + + from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env + from stable_baselines3.common.utils import set_random_seed + +instead of ``from stable_baselines3.common import make_atari_env`` + + + +Changes and renaming in parameters +---------------------------------- + +Base-class (all algorithms) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- ``load_parameters`` -> ``set_parameters`` + + - ``get/set_parameters`` return a dictionary mapping object names + to their respective PyTorch tensors and other objects representing + their parameters, instead of simpler mapping of parameter name to + a NumPy array. These functions also return PyTorch tensors rather + than NumPy arrays. + + +Policies +^^^^^^^^ + +- ``cnn_extractor`` -> ``feature_extractor``, as ``feature_extractor`` in now used with ``MlpPolicy`` too + +A2C +^^^ + +- ``epsilon`` -> ``rms_prop_eps`` +- ``lr_schedule`` is part of ``learning_rate`` (it can be a callable). +- ``alpha``, ``momentum`` are modifiable through ``policy_kwargs`` key ``optimizer_kwargs``. + +.. warning:: + + PyTorch implementation of RMSprop `differs from Tensorflow's `_, + which leads to `different and potentially more unstable results `_. + Use ``stable_baselines3.common.sb2_compat.rmsprop_tf_like.RMSpropTFLike`` optimizer to match the results + with Tensorflow's implementation. This can be done through ``policy_kwargs``: ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))`` + + +PPO +^^^ + +- ``cliprange`` -> ``clip_range`` +- ``cliprange_vf`` -> ``clip_range_vf`` +- ``nminibatches`` -> ``batch_size`` + +.. warning:: + + ``nminibatches`` gave different batch size depending on the number of environments: ``batch_size = (n_steps * n_envs) // nminibatches`` + + +- ``clip_range_vf`` behavior for PPO is slightly different: Set it to ``None`` (default) to deactivate clipping (in SB2, you had to pass ``-1``, ``None`` meant to use ``clip_range`` for the clipping) +- ``lam`` -> ``gae_lambda`` +- ``noptepochs`` -> ``n_epochs`` + +PPO default hyperparameters are the one tuned for continuous control environment. +We recommend taking a look at the :ref:`RL Zoo ` for hyperparameters tuned for Atari games. + + +DQN +^^^ + +Only the vanilla DQN is implemented right now but extensions will follow. +Default hyperparameters are taken from the nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. + +DDPG +^^^^ + +DDPG now follows the same interface as SAC/TD3. +For state/reward normalization, you should use ``VecNormalize`` as for all other algorithms. + +SAC/TD3 +^^^^^^^ + +SAC/TD3 now accept any number of critics, e.g. ``policy_kwargs=dict(n_critics=3)``, instead of only two before. + + +.. note:: + + SAC/TD3 default hyperparameters (including network architecture) now match the ones from the original papers. + DDPG is using TD3 defaults. + + +SAC +^^^ + +SAC implementation matches the latest version of the original implementation: it uses two Q function networks and two target Q function networks +instead of two Q function networks and one Value function network (SB2 implementation, first version of the original implementation). +Despite this change, no change in performance should be expected. + +.. note:: + + SAC ``predict()`` method has now ``deterministic=False`` by default for consistency. + To match SB2 behavior, you need to explicitly pass ``deterministic=True`` + + + +New logger API +-------------- + +- Methods were renamed in the logger: + + - ``logkv`` -> ``record``, ``writekvs`` -> ``write``, ``writeseq`` -> ``write_sequence``, + - ``logkvs`` -> ``record_dict``, ``dumpkvs`` -> ``dump``, + - ``getkvs`` -> ``get_log_dict``, ``logkv_mean`` -> ``record_mean``, + + +Internal Changes +---------------- + +Please read the :ref:`Developer Guide ` section. + + +New Features (SB3 vs SB2) +========================= + +- Much cleaner and consistent base code (and no more warnings =D!) and static type checks +- Independent saving/loading/predict for policies +- A2C now supports Generalized Advantage Estimation (GAE) and advantage normalization (both are deactivated by default) +- Generalized State-Dependent Exploration (gSDE) exploration is available for A2C/PPO/SAC. It allows to use RL directly on real robots (cf https://arxiv.org/abs/2005.05719) +- Proper evaluation (using separate env) is included in the base class (using ``EvalCallback``), + if you pass the environment as a string, you can pass ``create_eval_env=True`` to the algorithm constructor. +- Better saving/loading: optimizers are now included in the saved parameters and there is two new methods ``save_replay_buffer`` and ``load_replay_buffer`` for the replay buffer when using off-policy algorithms (DQN/DDPG/SAC/TD3) +- You can pass ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily + customize optimizers +- Seeding now works properly to have deterministic results +- Replay buffer does not grow, allocate everything at build time (faster) +- We added a memory efficient replay buffer variant (pass ``optimize_memory_usage=True`` to the constructor), it reduces drastically the memory used especially when using images +- You can specify an arbitrary number of critics for SAC/TD3 (e.g. ``policy_kwargs=dict(n_critics=3)``) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 90b5322..da32a8e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -26,6 +26,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ +- Added first draft of migration guide Pre-Release 0.9.0 (2020-10-03) @@ -36,8 +37,7 @@ Pre-Release 0.9.0 (2020-10-03) Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed ``device`` keyword argument of policies; use ``policy.to(device)`` instead. (@qxcv) -- Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and - ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params`` +- Rename ``BaseClass.get_torch_variables`` -> ``BaseClass._get_torch_save_params`` and ``BaseClass.excluded_save_params`` -> ``BaseClass._excluded_save_params`` - Renamed saved items ``tensors`` to ``pytorch_variables`` for clarity - ``make_atari_env``, ``make_vec_env`` and ``set_random_seed`` must be imported with (and not directly from ``stable_baselines3.common``): @@ -73,7 +73,7 @@ Others: - Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used - Fixed typos in SAC and TD3 - Reorganized functions for clarity in ``BaseClass`` (save/load functions close to each other, private - functions at top) + functions at top) - Clarified docstrings on what is saved and loaded to/from files - Simplified ``save_to_zip_file`` function by removing duplicate code - Store library version along with the saved models @@ -100,7 +100,7 @@ Breaking Changes: - Refactored ``Critic`` class for ``TD3`` and ``SAC``, it is now called ``ContinuousCritic`` and has an additional parameter ``n_critics`` - ``SAC`` and ``TD3`` now accept an arbitrary number of critics (e.g. ``policy_kwargs=dict(n_critics=3)``) - instead of only 2 previously + instead of only 2 previously New Features: ^^^^^^^^^^^^^