Merge branch 'master' into feat/redq

This commit is contained in:
Antonin Raffin 2022-08-17 14:52:47 +02:00
commit ebf6ed1d0a
No known key found for this signature in database
GPG key ID: B8B48F65CAD6232C
77 changed files with 724 additions and 398 deletions

View file

@ -28,7 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[extra,tests,docs]
# Use headless version
pip install opencv-python-headless

View file

@ -9,6 +9,7 @@ pytest:
- python --version
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
- MKL_THREADING_LAYER=GNU make pytest
coverage: '/^TOTAL.+?(\d+\%)$/'
doc-build:
script:

View file

@ -51,7 +51,7 @@ Documentation is available online: [https://stable-baselines3.readthedocs.io/](h
## Integrations
Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation.
Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation.
## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents
@ -77,14 +77,14 @@ Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.h
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
## Installation
**Note:** Stable-Baselines3 supports PyTorch >= 1.8.1.
**Note:** Stable-Baselines3 supports PyTorch >= 1.11
### Prerequisites
Stable Baselines3 requires Python 3.7+.
@ -122,7 +122,7 @@ from stable_baselines3 import PPO
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)
model.learn(total_timesteps=10_000)
obs = env.reset()
for i in range(1000):
@ -140,7 +140,7 @@ Or just train a model with a one liner if [the environment is registered in Gym]
```python
from stable_baselines3 import PPO
model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
```
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples.
@ -172,6 +172,7 @@ All the following examples can be executed online using Google colab notebooks:
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| QR-DQN<sup>[1](#f1)</sup> | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| RecurrentPPO<sup>[1](#f1)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| TQC<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
@ -231,7 +232,7 @@ To cite this repository in publications:
## Maintainers
Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave) and [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli).
Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.

BIN
docs/_static/img/split_graph.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

View file

@ -6,9 +6,9 @@ dependencies:
- cpuonly=1.0=0
- pip=21.1
- python=3.7
- pytorch=1.8.1=py3.7_cpu_0
- pytorch=1.11=py3.7_cpu_0
- pip:
- gym>=0.17.2
- gym==0.21
- cloudpickle
- opencv-python-headless
- pandas
@ -16,5 +16,5 @@ dependencies:
- matplotlib
- sphinx_autodoc_typehints
- sphinx>=4.2
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
- sphinx_rtd_theme>=1.0
- sphinx_copybutton

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
@ -25,6 +24,14 @@ try:
except ImportError:
enable_spell_check = False
# Try to enable copy button
try:
import sphinx_copybutton # noqa: F401
enable_copy_button = True
except ImportError:
enable_copy_button = False
# source code directory, relative to this file, for sphinx-autobuild
sys.path.insert(0, os.path.abspath(".."))
@ -46,13 +53,13 @@ sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt")
with open(version_file, "r") as file_handler:
with open(version_file) as file_handler:
__version__ = file_handler.read().strip()
# -- Project information -----------------------------------------------------
project = "Stable Baselines3"
copyright = "2020, Stable Baselines3"
copyright = "2022, Stable Baselines3"
author = "Stable Baselines3 Contributors"
# The short X.Y version
@ -84,6 +91,9 @@ extensions = [
if enable_spell_check:
extensions.append("sphinxcontrib.spelling")
if enable_copy_button:
extensions.append("sphinx_copybutton")
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@ -101,7 +111,7 @@ master_doc = "index"
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.

View file

@ -15,6 +15,7 @@ DQN ❌ ✔️ ❌ ❌
HER ✔️ ✔️ ❌ ❌ ❌
PPO ✔️ ✔️ ✔️ ✔️ ✔️
QR-DQN [#f1]_ ✔️ ❌ ❌ ✔️
RecurrentPPO [#f1]_ ✔️ ✔️ ✔️ ✔️ ✔️
SAC ✔️ ❌ ❌ ❌ ✔️
TD3 ✔️ ❌ ❌ ❌ ✔️
TQC [#f1]_ ✔️ ❌ ❌ ❌ ✔️
@ -26,8 +27,8 @@ Maskable PPO [#f1]_ ❌ ✔️ ✔️ ✔
.. [#f1] Implemented in `SB3 Contrib <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib>`_
.. note::
``Tuple`` observation spaces are not supported by any environment
however single-level ``Dict`` spaces are (cf. :ref:`Examples <examples>`).
``Tuple`` observation spaces are not supported by any environment,
however, single-level ``Dict`` spaces are (cf. :ref:`Examples <examples>`).
Actions ``gym.spaces``:

View file

@ -61,7 +61,7 @@ Then you can define and train a RL agent with:
model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
To check that your environment follows the gym interface, please use:
To check that your environment follows the Gym interface that SB3 supports, please use:
.. code-block:: python
@ -71,11 +71,11 @@ To check that your environment follows the gym interface, please use:
# It will check your custom environment and output additional warnings if needed
check_env(env)
Gym also have its own `env checker <https://www.gymlibrary.ml/content/api/#checking-api-conformity>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.
Alternatively, you may look at OpenAI Gym `built-in environments <https://gym.openai.com/docs/#available-environments>`_. However, the readers are cautioned as per OpenAI Gym `official wiki <https://github.com/openai/gym/wiki/FAQ>`_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them.
Alternatively, you may look at OpenAI Gym `built-in environments <https://www.gymlibrary.ml/>`_. However, the readers are cautioned as per OpenAI Gym `official wiki <https://github.com/openai/gym/wiki/FAQ>`_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them.
Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env):

View file

@ -729,6 +729,16 @@ to keep track of the agent progress.
model.learn(10_000)
SB3 with EnvPool or Isaac Gym
-----------------------------
Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_ accelerate the environment by
already providing a vectorized implementation.
To use SB3 with those tools, you must wrap the env with tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3,
you can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
Record a Video
--------------

View file

@ -6,7 +6,7 @@ Installation
Prerequisites
-------------
Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.8.1.
Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.11
Windows 10
~~~~~~~~~~
@ -54,6 +54,17 @@ Bleeding-edge version
pip install git+https://github.com/DLR-RM/stable-baselines3
.. note::
If you want to use latest gym version (0.24+), you have to use
.. code-block:: bash
pip install git+https://github.com/carlosluis/stable-baselines3@fix_tests
See `PR #780 <https://github.com/DLR-RM/stable-baselines3/pull/780>`_ for more information.
Development version
-------------------

View file

@ -48,11 +48,14 @@ Hugging Face 🤗
The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.
You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3
Most of them are available via the RL Zoo.
Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3
We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T
For up to date instructions (for instance for using ``package_to_hub()``), please take a look at the Huggingface SB3 package README: https://github.com/huggingface/huggingface_sb3
Installation
-------------
@ -137,3 +140,56 @@ Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a
filename="ppo-CartPole-v1",
commit_message="Added Cartpole-v1 model trained with PPO",
)
MLFLow
======
If you want to use `MLFLow <https://github.com/mlflow/mlflow>`_ to track your SB3 experiments,
you can adapt the following code which defines a custom logger output:
.. code-block:: python
import sys
from typing import Any, Dict, Tuple, Union
import mlflow
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger
class MLflowOutputFormat(KVWriter):
"""
Dumps key/value pairs into MLflow's numeric format.
"""
def write(
self,
key_values: Dict[str, Any],
key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
step: int = 0,
) -> None:
for (key, value), (_, excluded) in zip(
sorted(key_values.items()), sorted(key_excluded.items())
):
if excluded is not None and "mlflow" in excluded:
continue
if isinstance(value, np.ScalarType):
if not isinstance(value, str):
mlflow.log_metric(key, value, step)
loggers = Logger(
folder=None,
output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
)
with mlflow.start_run():
model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
# Set custom logger
model.set_logger(loggers)
model.learn(total_timesteps=10000, log_interval=1)

View file

@ -141,7 +141,7 @@ DQN
^^^
Only the vanilla DQN is implemented right now but extensions will follow.
Default hyperparameters are taken from the nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
DDPG
^^^^

View file

@ -183,6 +183,16 @@ Some basic advice:
- start with shaped reward (i.e. informative reward) and simplified version of your problem
- debug with random actions to check that your environment works and follows the gym interface:
Two important things to keep in mind when creating a custom environment is to avoid breaking Markov assumption
and properly handle termination due to a timeout (maximum number of steps in an episode).
For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give an history of observations
as input.
Termination due to timeout (max number of steps per episode) needs to be handled separately. You should fill the key in the info dict: ``info["TimeLimit.truncated"] = True``.
If you are using the gym ``TimeLimit`` wrapper, this will be done automatically.
You can read `Time Limit in RL <https://arxiv.org/abs/1712.00378>`_ or take a look at the `RL Tips and Tricks video <https://www.youtube.com/watch?v=Ikngt0_DXJg>`_
for more details.
We provide a helper to check that your environment runs without error:
@ -241,12 +251,15 @@ We *recommend following those steps to have a working RL algorithm*:
1. Read the original paper several times
2. Read existing implementations (if available)
3. Try to have some "sign of life" on toy problems
4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo)
You usually need to run hyperparameter optimization for that step.
4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo).
You usually need to run hyperparameter optimization for that step.
You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf `issue #75 <https://github.com/hill-a/stable-baselines/pull/76>`_)
You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf. `issue #75 <https://github.com/hill-a/stable-baselines/pull/76>`_)
and when to stop the gradient propagation.
Don't forget to handle termination due to timeout separately (see remark in the custom environment section above),
you can also take a look at `Issue #284 <https://github.com/DLR-RM/stable-baselines3/issues/284>`_ and `Issue #633 <https://github.com/DLR-RM/stable-baselines3/issues/633>`_.
A personal pick (by @araffin) for environments with gradual difficulty in RL with continuous actions:
1. Pendulum (easy to solve)

View file

@ -8,7 +8,7 @@ We implement experimental features in a separate contrib repository:
`SB3-Contrib`_
This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still
providing the latest features, like Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or
providing the latest features, like RecurrentPPO (PPO LSTM), Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or
Quantile Regression DQN (QR-DQN).
Why create this repository?
@ -38,9 +38,11 @@ See documentation for the full list of included features.
- `Augmented Random Search (ARS) <https://arxiv.org/abs/1803.07055>`_
- `Quantile Regression DQN (QR-DQN)`_
- `PPO with invalid action masking (Maskable PPO) <https://arxiv.org/abs/2006.14171>`_
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
- `PPO with invalid action masking (Maskable PPO) <https://arxiv.org/abs/2006.14171>`_
**Gym Wrappers**:

View file

@ -26,10 +26,20 @@ You can also define custom logging name when training (by default it is the algo
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
# Keep tb_log_name constant to have continuous curve (see note below)
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
.. note::
If you specify different ``tb_log_name`` in subsequent runs, you will have split graphs, like in the figure below.
If you want them to be continuous, you must keep the same ``tb_log_name`` (see `issue #975 <https://github.com/DLR-RM/stable-baselines3/issues/975#issuecomment-1198992211>`_).
And, if you still managed to get your graphs split by other means, just put tensorboard log files into the same folder.
.. image:: ../_static/img/split_graph.png
:width: 330
:alt: split_graph
Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command:
.. code-block:: bash

View file

@ -3,8 +3,7 @@
Changelog
==========
Release 1.5.1a0 (WIP)
Release 1.6.1a0 (WIP)
---------------------------
Breaking Changes:
@ -18,16 +17,80 @@ SB3-Contrib
Bug Fixes:
^^^^^^^^^^
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
- Added multidimensional action space support (@qgallouedec)
- Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb)
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec)
- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec)
Documentation:
^^^^^^^^^^^^^^
- Fixed typo in docstring "nature" -> "Nature" (@Melanol)
- Added info on split tensorboard logs into (@Melanol)
- Fixed typo in ppo doc (@francescoluciano)
- Fixed typo in install doc(@jlp-ue)
Release 1.6.0 (2022-07-11)
---------------------------
**Recurrent PPO (PPO LSTM), better defaults for learning from pixels with SAC/TD3**
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
- SB3 now requires PyTorch >= 1.11
- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with SAC or DDPG/TD3,
``share_features_extractor`` is now set to False by default and the ``net_arch=[256, 256]`` (instead of ``net_arch=[]`` that was before)
New Features:
^^^^^^^^^^^^^
SB3-Contrib
^^^^^^^^^^^
- Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53
Bug Fixes:
^^^^^^^^^^
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
- Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP)
- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled
- Added a check for unbounded actions
- Fixed issues due to newer version of protobuf (tensorboard) and sphinx
- Fix exception causes all over the codebase (@cool-RR)
- Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede)
- Fixed a bug in ``kl_divergence`` check that would fail when using numpy arrays with MultiCategorical distribution
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Upgraded to Python 3.7+ syntax using ``pyupgrade``
- Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG)
Documentation:
^^^^^^^^^^^^^^
- Added link to gym doc and gym env checker
- Fix typo in PPO doc (@bcollazo)
- Added link to PPO ICLR blog post
- Added remark about breaking Markov assumption and timeout handling
- Added doc about MLFlow integration via custom logger (@git-thor)
- Updated Huggingface integration doc
- Added copy button for code snippets
- Added doc about EnvPool and Isaac Gym support
Release 1.5.0 (2022-03-25)
@ -923,7 +986,8 @@ Maintainers
-----------
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_), `Anssi Kanervisto`_ (aka `@Miffyli`_)
and `Quentin Gallouédec`_ (aka @qgallouedec).
.. _Ashley Hill: https://github.com/hill-a
.. _Antonin Raffin: https://araffin.github.io/
@ -933,6 +997,8 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
.. _@AdamGleave: https://github.com/adamgleave
.. _Anssi Kanervisto: https://github.com/Miffyli
.. _@Miffyli: https://github.com/Miffyli
.. _Quentin Gallouédec: https://gallouedec.com/
.. _@qgallouedec: https://github.com/qgallouedec
@ -957,4 +1023,5 @@ And all the contributors:
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb

View file

@ -8,14 +8,14 @@ PPO
The `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`_ algorithm combines ideas from A2C (having multiple workers)
and TRPO (it uses a trust region to improve the actor).
The main idea is that after an update, the new policy should be not too far form the old policy.
The main idea is that after an update, the new policy should be not too far from the old policy.
For that, ppo uses clipping to avoid too large update.
.. note::
PPO contains several modifications from the original algorithm not documented
by OpenAI: advantages are normalized and value function can be also clipped .
by OpenAI: advantages are normalized and value function can be also clipped.
Notes
@ -25,11 +25,22 @@ Notes
- Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
- Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html
- 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
Can I use?
----------
.. note::
A recurrent version of PPO is available in our contrib repo: https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html
However we advise users to start with simple frame-stacking as a simpler, faster
and usually competitive alternative, more info in our report: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4
See also `Procgen paper appendix Fig 11. <https://arxiv.org/abs/1912.01588>`_.
In practice, you can stack multiple observations using ``VecFrameStack``.
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:

View file

@ -2,7 +2,7 @@ import os
from setuptools import find_packages, setup
with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler:
with open(os.path.join("stable_baselines3", "version.txt")) as file_handler:
__version__ = file_handler.read().strip()
@ -43,10 +43,10 @@ import gym
from stable_baselines3 import PPO
env = gym.make('CartPole-v1')
env = gym.make("CartPole-v1")
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
obs = env.reset()
for i in range(1000):
@ -57,12 +57,12 @@ for i in range(1000):
obs = env.reset()
```
Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
```python
from stable_baselines3 import PPO
model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
```
""" # noqa:E501
@ -75,7 +75,7 @@ setup(
install_requires=[
"gym==0.21", # Fixed version due to breaking changes in 0.22
"numpy",
"torch>=1.8.1",
"torch>=1.11",
# For saving models
"cloudpickle",
# For reading logs
@ -111,16 +111,21 @@ setup(
"sphinxcontrib.spelling",
# Type hints support
"sphinx-autodoc-typehints",
# Copy button for code snippets
"sphinx_copybutton",
],
"extra": [
# For render
"opencv-python",
# For atari games,
"ale-py~=0.7.4",
"ale-py==0.7.4",
"autorom[accept-rom-license]~=0.4.2",
"pillow",
# Tensorboard support
"tensorboard>=2.2.0",
# Protobuf >= 4 has breaking changes
# which does play well with tensorboard
"protobuf~=3.19.0",
# Checking memory taken by replay buffer
"psutil",
],

View file

@ -11,7 +11,7 @@ from stable_baselines3.td3 import TD3
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file, "r") as file_handler:
with open(version_file) as file_handler:
__version__ = file_handler.read().strip()

View file

@ -5,7 +5,7 @@ from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
@ -51,6 +51,12 @@ class A2C(OnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
}
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
@ -76,7 +82,7 @@ class A2C(OnPolicyAlgorithm):
_init_setup_model: bool = True,
):
super(A2C, self).__init__(
super().__init__(
policy,
env,
learning_rate=learning_rate,
@ -188,7 +194,7 @@ class A2C(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
) -> "A2C":
return super(A2C, self).learn(
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,

View file

@ -1,16 +1,7 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for A2C
from stable_baselines3.common.policies import (
ActorCriticCnnPolicy,
ActorCriticPolicy,
MultiInputActorCriticPolicy,
register_policy,
)
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
register_policy("MlpPolicy", ActorCriticPolicy)
register_policy("CnnPolicy", ActorCriticCnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -245,4 +245,4 @@ class AtariWrapper(gym.Wrapper):
if clip_reward:
env = ClipRewardEnv(env)
super(AtariWrapper, self).__init__(env)
super().__init__(env)

View file

@ -17,7 +17,7 @@ from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -60,7 +60,6 @@ class BaseAlgorithm(ABC):
:param policy: Policy object
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: The base policy used by this method
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param policy_kwargs: Additional arguments to be passed to the policy on creation
@ -83,11 +82,13 @@ class BaseAlgorithm(ABC):
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
# Policy aliases (see _get_policy_from_name())
policy_aliases: Dict[str, Type[BasePolicy]] = {}
def __init__(
self,
policy: Type[BasePolicy],
env: Union[GymEnv, str, None],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Schedule],
policy_kwargs: Optional[Dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
@ -101,9 +102,8 @@ class BaseAlgorithm(ABC):
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
):
if isinstance(policy, str) and policy_base is not None:
self.policy_class = get_policy_from_name(policy_base, policy)
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
else:
self.policy_class = policy
@ -185,6 +185,11 @@ class BaseAlgorithm(ABC):
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
if isinstance(self.action_space, gym.spaces.Box):
assert np.all(
np.isfinite(np.array([self.action_space.low, self.action_space.high]))
), "Continuous action space must have a finite lower and upper bound"
@staticmethod
def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
""" "
@ -209,11 +214,6 @@ class BaseAlgorithm(ABC):
# Make sure that dict-spaces are not nested (not supported)
check_for_nested_spaces(env.observation_space)
if isinstance(env.observation_space, gym.spaces.Dict):
for space in env.observation_space.spaces.values():
if isinstance(space, gym.spaces.Dict):
raise ValueError("Nested observation spaces are not supported (Dict spaces inside Dict space).")
if not is_vecenv_wrapped(env, VecTransposeImage):
wrap_with_vectranspose = False
if isinstance(env.observation_space, gym.spaces.Dict):
@ -325,6 +325,23 @@ class BaseAlgorithm(ABC):
"_custom_logger",
]
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
"""
Get a policy class from its name representation.
The goal here is to standardize policy naming, e.g.
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
and they receive respective policies that work for them.
:param policy_name: Alias of the policy
:return: A policy class (type)
"""
if policy_name in self.policy_aliases:
return self.policy_aliases[policy_name]
else:
raise ValueError(f"Policy {policy_name} unknown")
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variables that will be saved with
@ -375,6 +392,7 @@ class BaseAlgorithm(ABC):
log_path=log_path,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
verbose=self.verbose,
)
callback = CallbackList([callback, eval_callback])
@ -405,7 +423,7 @@ class BaseAlgorithm(ABC):
:param tb_log_name: the name of the run for tensorboard log
:return:
"""
self.start_time = time.time()
self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
@ -611,11 +629,11 @@ class BaseAlgorithm(ABC):
attr = None
try:
attr = recursive_getattr(self, name)
except Exception:
except Exception as e:
# What errors recursive_getattr could throw? KeyError, but
# possible something else too (e.g. if key is an int?).
# Catch anything for now.
raise ValueError(f"Key {name} is an invalid object name.")
raise ValueError(f"Key {name} is an invalid object name.") from e
if isinstance(attr, th.optim.Optimizer):
# Optimizers do not support "strict" keyword...

View file

@ -13,6 +13,7 @@ from stable_baselines3.common.type_aliases import (
ReplayBufferSamples,
RolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
try:
@ -39,10 +40,10 @@ class BaseBuffer(ABC):
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
n_envs: int = 1,
):
super(BaseBuffer, self).__init__()
super().__init__()
self.buffer_size = buffer_size
self.observation_space = observation_space
self.action_space = action_space
@ -51,7 +52,7 @@ class BaseBuffer(ABC):
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
self.device = device
self.device = get_device(device)
self.n_envs = n_envs
@staticmethod
@ -157,13 +158,14 @@ class ReplayBuffer(BaseBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device:
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
of the replay buffer which reduces by almost a factor two the memory used,
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
Cannot be used in combination with handle_timeout_termination.
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
separately and treat the task as infinite horizon task.
https://github.com/DLR-RM/stable-baselines3/issues/284
@ -174,12 +176,12 @@ class ReplayBuffer(BaseBuffer):
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
# Adjust buffer size
self.buffer_size = max(buffer_size // n_envs, 1)
@ -188,6 +190,13 @@ class ReplayBuffer(BaseBuffer):
if psutil is not None:
mem_available = psutil.virtual_memory().available
# there is a bug if both optimize_memory_usage and handle_timeout_termination are true
# see https://github.com/DLR-RM/stable-baselines3/issues/934
if optimize_memory_usage and handle_timeout_termination:
raise ValueError(
"ReplayBuffer does not support optimize_memory_usage = True "
"and handle_timeout_termination = True simultaneously."
)
self.optimize_memory_usage = optimize_memory_usage
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
@ -239,8 +248,7 @@ class ReplayBuffer(BaseBuffer):
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
# Same, for actions
if isinstance(self.action_space, spaces.Discrete):
action = action.reshape((self.n_envs, self.action_dim))
action = action.reshape((self.n_envs, self.action_dim))
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy()
@ -321,7 +329,7 @@ class RolloutBuffer(BaseBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device:
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
@ -333,13 +341,13 @@ class RolloutBuffer(BaseBuffer):
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
@ -358,7 +366,7 @@ class RolloutBuffer(BaseBuffer):
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
super(RolloutBuffer, self).reset()
super().reset()
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
"""
@ -425,6 +433,9 @@ class RolloutBuffer(BaseBuffer):
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs,) + self.obs_shape)
# Same reshape, for actions
action = action.reshape((self.n_envs, self.action_dim))
self.observations[self.pos] = np.array(obs).copy()
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
@ -483,7 +494,7 @@ class DictReplayBuffer(ReplayBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device:
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
@ -497,7 +508,7 @@ class DictReplayBuffer(ReplayBuffer):
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
@ -578,8 +589,7 @@ class DictReplayBuffer(ReplayBuffer):
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
# Same reshape, for actions
if isinstance(self.action_space, spaces.Discrete):
action = action.reshape((self.n_envs, self.action_dim))
action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action).copy()
self.rewards[self.pos] = np.array(reward).copy()
@ -649,7 +659,7 @@ class DictRolloutBuffer(RolloutBuffer):
:param buffer_size: Max number of element in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device:
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to Monte-Carlo advantage estimate when set to 1.
:param gamma: Discount factor
@ -661,7 +671,7 @@ class DictRolloutBuffer(RolloutBuffer):
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,

View file

@ -19,7 +19,7 @@ class BaseCallback(ABC):
"""
def __init__(self, verbose: int = 0):
super(BaseCallback, self).__init__()
super().__init__()
# The RL model
self.model = None # type: Optional[base_class.BaseAlgorithm]
# An alias for self.model.get_env(), the environment used for training
@ -127,14 +127,14 @@ class EventCallback(BaseCallback):
"""
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
super(EventCallback, self).__init__(verbose=verbose)
super().__init__(verbose=verbose)
self.callback = callback
# Give access to the parent
if callback is not None:
self.callback.parent = self
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
super(EventCallback, self).init_callback(model)
super().init_callback(model)
if self.callback is not None:
self.callback.init_callback(self.model)
@ -169,7 +169,7 @@ class CallbackList(BaseCallback):
"""
def __init__(self, callbacks: List[BaseCallback]):
super(CallbackList, self).__init__()
super().__init__()
assert isinstance(callbacks, list)
self.callbacks = callbacks
@ -228,7 +228,7 @@ class CheckpointCallback(BaseCallback):
"""
def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
super(CheckpointCallback, self).__init__(verbose)
super().__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
self.name_prefix = name_prefix
@ -256,7 +256,7 @@ class ConvertCallback(BaseCallback):
"""
def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
super(ConvertCallback, self).__init__(verbose)
super().__init__(verbose)
self.callback = callback
def _on_step(self) -> bool:
@ -307,7 +307,7 @@ class EvalCallback(EventCallback):
verbose: int = 1,
warn: bool = True,
):
super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose)
super().__init__(callback_after_eval, verbose=verbose)
self.callback_on_new_best = callback_on_new_best
if self.callback_on_new_best is not None:
@ -380,12 +380,12 @@ class EvalCallback(EventCallback):
if self.model.get_vec_normalize_env() is not None:
try:
sync_envs_normalization(self.training_env, self.eval_env)
except AttributeError:
except AttributeError as e:
raise AssertionError(
"Training and eval env are not wrapped the same way, "
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
"and warning above."
)
) from e
# Reset success rate buffer
self._is_success_buffer = []
@ -480,7 +480,7 @@ class StopTrainingOnRewardThreshold(BaseCallback):
"""
def __init__(self, reward_threshold: float, verbose: int = 0):
super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose)
super().__init__(verbose=verbose)
self.reward_threshold = reward_threshold
def _on_step(self) -> bool:
@ -505,7 +505,7 @@ class EveryNTimesteps(EventCallback):
"""
def __init__(self, n_steps: int, callback: BaseCallback):
super(EveryNTimesteps, self).__init__(callback)
super().__init__(callback)
self.n_steps = n_steps
self.last_time_trigger = 0
@ -528,7 +528,7 @@ class StopTrainingOnMaxEpisodes(BaseCallback):
"""
def __init__(self, max_episodes: int, verbose: int = 0):
super(StopTrainingOnMaxEpisodes, self).__init__(verbose=verbose)
super().__init__(verbose=verbose)
self.max_episodes = max_episodes
self._total_max_episodes = max_episodes
self.n_episodes = 0
@ -573,7 +573,7 @@ class StopTrainingOnNoModelImprovement(BaseCallback):
"""
def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose)
super().__init__(verbose=verbose)
self.max_no_improvement_evals = max_no_improvement_evals
self.min_evals = min_evals
self.last_best_mean_reward = -np.inf

View file

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from torch import nn
@ -16,7 +17,7 @@ class Distribution(ABC):
"""Abstract base class for distributions."""
def __init__(self):
super(Distribution, self).__init__()
super().__init__()
self.distribution = None
@abstractmethod
@ -120,7 +121,7 @@ class DiagGaussianDistribution(Distribution):
"""
def __init__(self, action_dim: int):
super(DiagGaussianDistribution, self).__init__()
super().__init__()
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None
@ -201,13 +202,13 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
"""
def __init__(self, action_dim: int, epsilon: float = 1e-6):
super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
super().__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
self.gaussian_actions = None
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
super().proba_distribution(mean_actions, log_std)
return self
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
@ -219,7 +220,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
gaussian_actions = TanhBijector.inverse(actions)
# Log likelihood for a Gaussian distribution
log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions)
log_prob = super().log_prob(gaussian_actions)
# Squash correction (from original SAC implementation)
# this comes from the fact that tanh is bijective and differentiable
log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
@ -254,7 +255,7 @@ class CategoricalDistribution(Distribution):
"""
def __init__(self, action_dim: int):
super(CategoricalDistribution, self).__init__()
super().__init__()
self.action_dim = action_dim
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@ -305,7 +306,7 @@ class MultiCategoricalDistribution(Distribution):
"""
def __init__(self, action_dims: List[int]):
super(MultiCategoricalDistribution, self).__init__()
super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@ -360,7 +361,7 @@ class BernoulliDistribution(Distribution):
"""
def __init__(self, action_dims: int):
super(BernoulliDistribution, self).__init__()
super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
@ -433,7 +434,7 @@ class StateDependentNoiseDistribution(Distribution):
learn_features: bool = False,
epsilon: float = 1e-6,
):
super(StateDependentNoiseDistribution, self).__init__()
super().__init__()
self.action_dim = action_dim
self.latent_sde_dim = None
self.mean_actions = None
@ -577,10 +578,10 @@ class StateDependentNoiseDistribution(Distribution):
return th.mm(latent_sde, self.exploration_mat)
# Use batch matrix multiplication for efficient computation
# (batch_size, n_features) -> (batch_size, 1, n_features)
latent_sde = latent_sde.unsqueeze(1)
latent_sde = latent_sde.unsqueeze(dim=1)
# (batch_size, 1, n_actions)
noise = th.bmm(latent_sde, self.exploration_matrices)
return noise.squeeze(1)
return noise.squeeze(dim=1)
def actions_from_params(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
@ -597,7 +598,7 @@ class StateDependentNoiseDistribution(Distribution):
return actions, log_prob
class TanhBijector(object):
class TanhBijector:
"""
Bijective transformation of a probability distribution
using a squashing function (tanh)
@ -607,7 +608,7 @@ class TanhBijector(object):
"""
def __init__(self, epsilon: float = 1e-6):
super(TanhBijector, self).__init__()
super().__init__()
self.epsilon = epsilon
@staticmethod
@ -657,7 +658,6 @@ def make_proba_distribution(
dist_kwargs = {}
if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
@ -688,7 +688,7 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
if isinstance(dist_pred, MultiCategoricalDistribution):
assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space"
assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space"
return th.stack(
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
dim=1,

View file

@ -147,7 +147,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
try:
_check_obs(obs[key], observation_space.spaces[key], "reset")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e))
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "reset")
@ -166,7 +166,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
try:
_check_obs(obs[key], observation_space.spaces[key], "step")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e))
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "step")
@ -274,6 +274,11 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
"cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
)
if isinstance(action_space, spaces.Box):
assert np.all(
np.isfinite(np.array([action_space.low, action_space.high]))
), "Continuous action space must have a finite lower and upper bound"
if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32):
warnings.warn(
f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."

View file

@ -36,7 +36,7 @@ class BitFlippingEnv(GoalEnv):
image_obs_space: bool = False,
channel_first: bool = True,
):
super(BitFlippingEnv, self).__init__()
super().__init__()
# Shape of the observation when using image space
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
# The achieved goal is determined by the current state
@ -115,7 +115,7 @@ class BitFlippingEnv(GoalEnv):
if self.discrete_obs_space:
# The internal state is the binary representation of the
# observed one
return int(sum([state[i] * 2**i for i in range(len(state))]))
return int(sum(state[i] * 2**i for i in range(len(state))))
if self.image_obs_space:
size = np.prod(self.image_shape)
@ -135,7 +135,7 @@ class BitFlippingEnv(GoalEnv):
if isinstance(state, int):
state = np.array(state).reshape(batch_size, -1)
# Convert to binary representation
state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int)
state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
elif self.image_obs_space:
state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
else:

View file

@ -42,7 +42,7 @@ class SimpleMultiObsEnv(gym.Env):
discrete_actions: bool = True,
channel_last: bool = True,
):
super(SimpleMultiObsEnv, self).__init__()
super().__init__()
self.vector_size = 5
if channel_last:

View file

@ -17,6 +17,7 @@ try:
except ImportError:
SummaryWriter = None
DEBUG = 10
INFO = 20
WARN = 30
@ -24,7 +25,7 @@ ERROR = 40
DISABLED = 50
class Video(object):
class Video:
"""
Video data class storing the video frames and the frame per seconds
@ -37,7 +38,7 @@ class Video(object):
self.fps = fps
class Figure(object):
class Figure:
"""
Figure data class storing a matplotlib figure and whether to close the figure after logging it
@ -50,7 +51,7 @@ class Figure(object):
self.close = close
class Image(object):
class Image:
"""
Image data class storing an image and data format
@ -80,13 +81,13 @@ class FormatUnsupportedError(NotImplementedError):
format_str = f"formats {', '.join(unsupported_formats)} are"
else:
format_str = f"format {unsupported_formats[0]} is"
super(FormatUnsupportedError, self).__init__(
super().__init__(
f"The {format_str} not supported for the {value_description} value logged.\n"
f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
)
class KVWriter(object):
class KVWriter:
"""
Key Value writer
"""
@ -108,7 +109,7 @@ class KVWriter(object):
raise NotImplementedError
class SeqWriter(object):
class SeqWriter:
"""
sequence writer
"""
@ -246,12 +247,13 @@ def filter_excluded_keys(
class JSONOutputFormat(KVWriter):
def __init__(self, filename: str):
"""
log to a file, in the JSON format
"""
Log to a file, in the JSON format
:param filename: the file to write the log to
"""
:param filename: the file to write the log to
"""
def __init__(self, filename: str):
self.file = open(filename, "wt")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
@ -287,13 +289,13 @@ class JSONOutputFormat(KVWriter):
class CSVOutputFormat(KVWriter):
"""
Log to a file, in a CSV format
:param filename: the file to write the log to
"""
def __init__(self, filename: str):
"""
log to a file, in a CSV format
:param filename: the file to write the log to
"""
self.file = open(filename, "w+t")
self.keys = []
self.separator = ","
@ -351,12 +353,13 @@ class CSVOutputFormat(KVWriter):
class TensorBoardOutputFormat(KVWriter):
def __init__(self, folder: str):
"""
Dumps key/value pairs into TensorBoard's numeric format.
"""
Dumps key/value pairs into TensorBoard's numeric format.
:param folder: the folder to write the log to
"""
:param folder: the folder to write the log to
"""
def __init__(self, folder: str):
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
self.writer = SummaryWriter(log_dir=folder)
@ -427,7 +430,7 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWr
# ================================================================
class Logger(object):
class Logger:
"""
The logger class.
@ -623,7 +626,7 @@ def read_json(filename: str) -> pandas.DataFrame:
:return: the data in the json
"""
data = []
with open(filename, "rt") as file_handler:
with open(filename) as file_handler:
for line in file_handler:
data.append(json.loads(line))
return pandas.DataFrame(data)

View file

@ -36,7 +36,7 @@ class Monitor(gym.Wrapper):
reset_keywords: Tuple[str, ...] = (),
info_keywords: Tuple[str, ...] = (),
):
super(Monitor, self).__init__(env=env)
super().__init__(env=env)
self.t_start = time.time()
if filename is not None:
self.results_writer = ResultsWriter(
@ -110,7 +110,7 @@ class Monitor(gym.Wrapper):
"""
Closes the environment
"""
super(Monitor, self).close()
super().close()
if self.results_writer is not None:
self.results_writer.close()
@ -224,7 +224,7 @@ def load_results(path: str) -> pandas.DataFrame:
raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
data_frames, headers = [], []
for file_name in monitor_files:
with open(file_name, "rt") as file_handler:
with open(file_name) as file_handler:
first_line = file_handler.readline()
assert first_line[0] == "#"
header = json.loads(first_line[1:])

View file

@ -11,7 +11,7 @@ class ActionNoise(ABC):
"""
def __init__(self):
super(ActionNoise, self).__init__()
super().__init__()
def reset(self) -> None:
"""
@ -35,7 +35,7 @@ class NormalActionNoise(ActionNoise):
def __init__(self, mean: np.ndarray, sigma: np.ndarray):
self._mu = mean
self._sigma = sigma
super(NormalActionNoise, self).__init__()
super().__init__()
def __call__(self) -> np.ndarray:
return np.random.normal(self._mu, self._sigma)
@ -72,7 +72,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
self.initial_noise = initial_noise
self.noise_prev = np.zeros_like(self._mu)
self.reset()
super(OrnsteinUhlenbeckActionNoise, self).__init__()
super().__init__()
def __call__(self) -> np.ndarray:
noise = (
@ -105,8 +105,8 @@ class VectorizedActionNoise(ActionNoise):
try:
self.n_envs = int(n_envs)
assert self.n_envs > 0
except (TypeError, AssertionError):
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")
except (TypeError, AssertionError) as e:
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e
self.base_noise = base_noise
self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]

View file

@ -1,5 +1,6 @@
import io
import pathlib
import sys
import time
import warnings
from copy import deepcopy
@ -28,7 +29,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
:param policy: Policy object
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: The base policy used by this method
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
@ -76,7 +76,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
self,
policy: Type[BasePolicy],
env: Union[GymEnv, str],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Schedule],
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
@ -104,10 +103,9 @@ class OffPolicyAlgorithm(BaseAlgorithm):
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
):
super(OffPolicyAlgorithm, self).__init__(
super().__init__(
policy=policy,
env=env,
policy_base=policy_base,
learning_rate=learning_rate,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
@ -160,8 +158,10 @@ class OffPolicyAlgorithm(BaseAlgorithm):
try:
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
except ValueError:
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
except ValueError as e:
raise ValueError(
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
) from e
if not isinstance(train_freq[0], int):
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
@ -428,8 +428,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
"""
Write log.
"""
time_elapsed = time.time() - self.start_time
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))

View file

@ -1,3 +1,4 @@
import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, Union
@ -8,7 +9,7 @@ import torch as th
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
@ -34,7 +35,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param policy_base: The base policy used by this method
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
@ -62,7 +62,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
policy_base: Type[BasePolicy] = ActorCriticPolicy,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
monitor_wrapper: bool = True,
@ -74,10 +73,9 @@ class OnPolicyAlgorithm(BaseAlgorithm):
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
):
super(OnPolicyAlgorithm, self).__init__(
super().__init__(
policy=policy,
env=env,
policy_base=policy_base,
learning_rate=learning_rate,
policy_kwargs=policy_kwargs,
verbose=verbose,
@ -257,13 +255,14 @@ class OnPolicyAlgorithm(BaseAlgorithm):
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)

View file

@ -67,7 +67,7 @@ class BaseModel(nn.Module, ABC):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(BaseModel, self).__init__()
super().__init__()
if optimizer_kwargs is None:
optimizer_kwargs = {}
@ -267,7 +267,7 @@ class BasePolicy(BaseModel):
"""
def __init__(self, *args, squash_output: bool = False, **kwargs):
super(BasePolicy, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self._squash_output = squash_output
@staticmethod
@ -336,8 +336,8 @@ class BasePolicy(BaseModel):
with th.no_grad():
actions = self._predict(observation, deterministic=deterministic)
# Convert to numpy
actions = actions.cpu().numpy()
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
@ -350,7 +350,7 @@ class BasePolicy(BaseModel):
# Remove batch dimension if needed
if not vectorized_env:
actions = actions[0]
actions = actions.squeeze(axis=0)
return actions, state
@ -437,7 +437,7 @@ class ActorCriticPolicy(BasePolicy):
if optimizer_class == th.optim.Adam:
optimizer_kwargs["eps"] = 1e-5
super(ActorCriticPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
features_extractor_class,
@ -592,6 +592,7 @@ class ActorCriticPolicy(BasePolicy):
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1,) + self.action_space.shape)
return actions, values, log_prob
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
@ -724,7 +725,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(ActorCriticCnnPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -799,7 +800,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(MultiInputActorCriticPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -895,68 +896,3 @@ class ContinuousCritic(BaseModel):
with th.no_grad():
features = self.extract_features(obs)
return self.q_networks[0](th.cat([features, actions], dim=1))
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
"""
Returns the registered policy from the base type and name.
See `register_policy` for registering policies and explanation.
:param base_policy_type: the base policy class
:param name: the policy name
:return: the policy
"""
if base_policy_type not in _policy_registry:
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
if name not in _policy_registry[base_policy_type]:
raise KeyError(
f"Error: unknown policy type {name},"
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!"
)
return _policy_registry[base_policy_type][name]
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
"""
Register a policy, so it can be called using its name.
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...).
The goal here is to standardize policy naming, e.g.
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
and they receive respective policies that work for them.
Consider following:
OnlinePolicy
-- OnlineMlpPolicy ("MlpPolicy")
-- OnlineCnnPolicy ("CnnPolicy")
OfflinePolicy
-- OfflineMlpPolicy ("MlpPolicy")
-- OfflineCnnPolicy ("CnnPolicy")
Two policies have name "MlpPolicy" and two have "CnnPolicy".
In `get_policy_from_name`, the parent class (e.g. OnlinePolicy)
is given and used to select and return the correct policy.
:param name: the policy name
:param policy: the policy class
"""
sub_class = None
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
if sub_class is None:
raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")
if sub_class not in _policy_registry:
_policy_registry[sub_class] = {}
if name in _policy_registry[sub_class]:
# Check if the registered policy is same
# we try to register. If not so,
# do not override and complain.
if _policy_registry[sub_class][name] != policy:
raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.")
_policy_registry[sub_class][name] = policy

View file

@ -3,7 +3,7 @@ from typing import Tuple, Union
import numpy as np
class RunningMeanStd(object):
class RunningMeanStd:
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
"""
Calulates the running mean and std of a data stream

View file

@ -206,8 +206,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
mode = mode.lower()
try:
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
except KeyError:
raise ValueError("Expected mode to be either 'w' or 'r'.")
except KeyError as e:
raise ValueError("Expected mode to be either 'w' or 'r'.") from e
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
e1 = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {e1} file.")
@ -441,7 +441,7 @@ def load_from_zip_file(
# State dicts. Store into params dictionary
# with same name as in .zip file (without .pth)
params[os.path.splitext(file_path)[0]] = th_object
except zipfile.BadZipFile:
except zipfile.BadZipFile as e:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
return data, params, pytorch_variables

View file

@ -54,21 +54,21 @@ class RMSpropTFLike(Optimizer):
centered: bool = False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= momentum:
raise ValueError("Invalid momentum value: {}".format(momentum))
raise ValueError(f"Invalid momentum value: {momentum}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))
raise ValueError(f"Invalid alpha value: {alpha}")
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
super(RMSpropTFLike, self).__init__(params, defaults)
super().__init__(params, defaults)
def __setstate__(self, state: Dict[str, Any]) -> None:
super(RMSpropTFLike, self).__setstate__(state)
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
group.setdefault("centered", False)

View file

@ -19,7 +19,7 @@ class BaseFeaturesExtractor(nn.Module):
"""
def __init__(self, observation_space: gym.Space, features_dim: int = 0):
super(BaseFeaturesExtractor, self).__init__()
super().__init__()
assert features_dim > 0
self._observation_space = observation_space
self._features_dim = features_dim
@ -41,7 +41,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
"""
def __init__(self, observation_space: gym.Space):
super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
super().__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()
def forward(self, observations: th.Tensor) -> th.Tensor:
@ -50,7 +50,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
class NatureCNN(BaseFeaturesExtractor):
"""
CNN from DQN nature paper:
CNN from DQN Nature paper:
Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529-533.
@ -61,7 +61,7 @@ class NatureCNN(BaseFeaturesExtractor):
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
super(NatureCNN, self).__init__(observation_space, features_dim)
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space, check_channels=False), (
@ -169,7 +169,7 @@ class MlpExtractor(nn.Module):
activation_fn: Type[nn.Module],
device: Union[th.device, str] = "auto",
):
super(MlpExtractor, self).__init__()
super().__init__()
device = get_device(device)
shared_net, policy_net, value_net = [], [], []
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
@ -250,7 +250,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super(CombinedExtractor, self).__init__(observation_space, features_dim=1)
super().__init__(observation_space, features_dim=1)
extractors = {}

View file

@ -50,7 +50,7 @@ class ReplayBufferSamples(NamedTuple):
class DictReplayBufferSamples(ReplayBufferSamples):
observations: TensorDict
actions: th.Tensor
next_observations: th.Tensor
next_observations: TensorDict
dones: th.Tensor
rewards: th.Tensor

View file

@ -154,15 +154,18 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
return device
def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int:
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
"""
Returns the latest run number for the given log name and log path,
by finding the greatest number in the directories.
:param log_path: Path to the log folder containing several runs.
:param log_name: Name of the experiment. Each run is stored
in a folder named ``log_name_1``, ``log_name_2``, ...
:return: latest run number
"""
max_run_id = 0
for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"):
for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
file_name = path.split(os.sep)[-1]
ext = file_name.split("_")[-1]
if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:

View file

@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
# Only synchronize if observation normalization exists
if hasattr(env_tmp, "obs_rms"):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
eval_env_tmp = eval_env_tmp.venv

View file

@ -305,7 +305,7 @@ class VecEnvWrapper(VecEnv):
own_class = f"{type(self).__module__}.{type(self).__name__}"
error_str = (
f"Error: Recursive attribute lookup for {name} from {own_class} is "
"ambiguous and hides attribute from {blocked_class}"
f"ambiguous and hides attribute from {blocked_class}"
)
raise AttributeError(error_str)

View file

@ -51,7 +51,9 @@ class DummyVecEnv(VecEnv):
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
seeds = list()
if seed is None:
seed = np.random.randint(0, 2**32 - 1)
seeds = []
for idx, env in enumerate(self.envs):
seeds.append(env.seed(seed + idx))
return seeds

View file

@ -7,7 +7,7 @@ from gym import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
class StackedObservations(object):
class StackedObservations:
"""
Frame stacking wrapper for data.

View file

@ -123,6 +123,8 @@ class SubprocVecEnv(VecEnv):
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
if seed is None:
seed = np.random.randint(0, 2**32 - 1)
for idx, remote in enumerate(self.remotes):
remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]
@ -215,6 +217,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len))
else:
return np.stack(obs)

View file

@ -37,7 +37,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) ->
return obs_dict
elif isinstance(obs_space, gym.spaces.Tuple):
assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
return tuple((obs_dict[i] for i in range(len(obs_space.spaces))))
return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
else:
assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
return obs_dict[None]

View file

@ -26,7 +26,7 @@ class VecTransposeImage(VecEnvWrapper):
self.skip = skip
# Do nothing
if skip:
super(VecTransposeImage, self).__init__(venv)
super().__init__(venv)
return
if isinstance(venv.observation_space, spaces.dict.Dict):
@ -39,7 +39,7 @@ class VecTransposeImage(VecEnvWrapper):
observation_space.spaces[key] = self.transpose_space(space, key)
else:
observation_space = self.transpose_space(venv.observation_space)
super(VecTransposeImage, self).__init__(venv, observation_space=observation_space)
super().__init__(venv, observation_space=observation_space)
@staticmethod
def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:

View file

@ -78,7 +78,7 @@ class DDPG(TD3):
_init_setup_model: bool = True,
):
super(DDPG, self).__init__(
super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
@ -127,7 +127,7 @@ class DDPG(TD3):
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super(DDPG, self).learn(
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,

View file

@ -8,10 +8,11 @@ from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
from stable_baselines3.dqn.policies import DQNPolicy
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
class DQN(OffPolicyAlgorithm):
@ -19,7 +20,7 @@ class DQN(OffPolicyAlgorithm):
Deep Q-Network (DQN)
Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
Default hyperparameters are taken from the nature paper,
Default hyperparameters are taken from the Nature paper,
except for the optimizer and learning rate that were taken from Stable Baselines defaults.
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
@ -59,6 +60,12 @@ class DQN(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[DQNPolicy]],
@ -88,10 +95,9 @@ class DQN(OffPolicyAlgorithm):
_init_setup_model: bool = True,
):
super(DQN, self).__init__(
super().__init__(
policy,
env,
DQNPolicy,
learning_rate,
buffer_size,
learning_starts,
@ -132,7 +138,7 @@ class DQN(OffPolicyAlgorithm):
self._setup_model()
def _setup_model(self) -> None:
super(DQN, self)._setup_model()
super()._setup_model()
self._create_aliases()
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps,
@ -255,7 +261,7 @@ class DQN(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super(DQN, self).learn(
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
@ -268,7 +274,7 @@ class DQN(OffPolicyAlgorithm):
)
def _excluded_save_params(self) -> List[str]:
return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]
return super()._excluded_save_params() + ["q_net", "q_net_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

View file

@ -4,7 +4,7 @@ import gym
import torch as th
from torch import nn
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
@ -37,7 +37,7 @@ class QNetwork(BasePolicy):
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super(QNetwork, self).__init__(
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
@ -118,7 +118,7 @@ class DQNPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(DQNPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
features_extractor_class,
@ -239,7 +239,7 @@ class CnnPolicy(DQNPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(CnnPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -284,7 +284,7 @@ class MultiInputPolicy(DQNPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(MultiInputPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -296,8 +296,3 @@ class MultiInputPolicy(DQNPolicy):
optimizer_class,
optimizer_kwargs,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in
if current_max_episode_length is None:
raise AttributeError
# if not available check if a valid value was passed as an argument
except AttributeError:
except AttributeError as e:
raise ValueError(
"The max episode length could not be inferred.\n"
"You must specify a `max_episode_steps` when registering the environment,\n"
"use a `gym.wrappers.TimeLimit` wrapper "
"or pass `max_episode_length` to the model constructor"
)
) from e
return current_max_episode_length
@ -73,7 +73,7 @@ class HerReplayBuffer(DictReplayBuffer):
self,
env: VecEnv,
buffer_size: int,
device: Union[th.device, str] = "cpu",
device: Union[th.device, str] = "auto",
replay_buffer: Optional[DictReplayBuffer] = None,
max_episode_length: Optional[int] = None,
n_sampled_goal: int = 4,
@ -82,7 +82,7 @@ class HerReplayBuffer(DictReplayBuffer):
handle_timeout_termination: bool = True,
):
super(HerReplayBuffer, self).__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
# convert goal_selection_strategy into GoalSelectionStrategy if string
if isinstance(goal_selection_strategy, str):
@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer):
elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
# replay with random state which comes from the same episode and was observed after current transition
transitions_indices = np.random.randint(
transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices]
transitions_indices[her_indices], self.episode_lengths[her_episode_indices]
)
elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE:
@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer):
else:
raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")
return self._buffer["achieved_goal"][her_episode_indices, transitions_indices]
return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices]
def _sample_transitions(
self,
@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer):
ep_lengths = self.episode_lengths[episode_indices]
# Special case when using the "future" goal sampling strategy
# we cannot sample all transitions, we have to remove the last timestep
if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
# restrict the sampling domain when ep_lengths > 1
# otherwise filter out the indices
her_indices = her_indices[ep_lengths[her_indices] > 1]
ep_lengths[her_indices] -= 1
if online_sampling:
# Select which transitions to use
transitions_indices = np.random.randint(ep_lengths)

View file

@ -1,16 +1,7 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for PPO
from stable_baselines3.common.policies import (
ActorCriticCnnPolicy,
ActorCriticPolicy,
MultiInputActorCriticPolicy,
register_policy,
)
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
register_policy("MlpPolicy", ActorCriticPolicy)
register_policy("CnnPolicy", ActorCriticCnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -7,7 +7,7 @@ from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
@ -19,7 +19,7 @@ class PPO(OnPolicyAlgorithm):
Paper: https://arxiv.org/abs/1707.06347
Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
@ -65,6 +65,12 @@ class PPO(OnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
}
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
@ -93,7 +99,7 @@ class PPO(OnPolicyAlgorithm):
_init_setup_model: bool = True,
):
super(PPO, self).__init__(
super().__init__(
policy,
env,
learning_rate=learning_rate,
@ -156,7 +162,7 @@ class PPO(OnPolicyAlgorithm):
self._setup_model()
def _setup_model(self) -> None:
super(PPO, self)._setup_model()
super()._setup_model()
# Initialize schedules for policy/value clipping
self.clip_range = get_schedule_fn(self.clip_range)
@ -301,7 +307,7 @@ class PPO(OnPolicyAlgorithm):
reset_num_timesteps: bool = True,
) -> "PPO":
return super(PPO, self).learn(
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,

View file

@ -6,7 +6,7 @@ import torch as th
from torch import nn
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -65,7 +65,7 @@ class Actor(BasePolicy):
clip_mean: float = 2.0,
normalize_images: bool = True,
):
super(Actor, self).__init__(
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
@ -235,9 +235,9 @@ class SACPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super(SACPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
features_extractor_class,
@ -248,10 +248,7 @@ class SACPolicy(BasePolicy):
)
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [256, 256]
net_arch = [256, 256]
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
@ -422,9 +419,9 @@ class CnnPolicy(SACPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super(CnnPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -493,9 +490,9 @@ class MultiInputPolicy(SACPolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super(MultiInputPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -514,8 +511,3 @@ class MultiInputPolicy(SACPolicy):
n_critics,
share_features_extractor,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -8,9 +8,10 @@ from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import SACPolicy
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
class SAC(OffPolicyAlgorithm):
@ -72,6 +73,12 @@ class SAC(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[SACPolicy]],
@ -103,10 +110,9 @@ class SAC(OffPolicyAlgorithm):
_init_setup_model: bool = True,
):
super(SAC, self).__init__(
super().__init__(
policy,
env,
SACPolicy,
learning_rate,
buffer_size,
learning_starts,
@ -144,7 +150,7 @@ class SAC(OffPolicyAlgorithm):
self._setup_model()
def _setup_model(self) -> None:
super(SAC, self)._setup_model()
super()._setup_model()
self._create_aliases()
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == "auto":
@ -248,7 +254,7 @@ class SAC(OffPolicyAlgorithm):
current_q_values = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
critic_losses.append(critic_loss.item())
# Optimize the critic
@ -297,7 +303,7 @@ class SAC(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super(SAC, self).learn(
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
@ -310,7 +316,7 @@ class SAC(OffPolicyAlgorithm):
)
def _excluded_save_params(self) -> List[str]:
return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]

View file

@ -4,7 +4,7 @@ import gym
import torch as th
from torch import nn
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
@ -42,7 +42,7 @@ class Actor(BasePolicy):
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super(Actor, self).__init__(
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
@ -119,9 +119,9 @@ class TD3Policy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super(TD3Policy, self).__init__(
super().__init__(
observation_space,
action_space,
features_extractor_class,
@ -134,7 +134,7 @@ class TD3Policy(BasePolicy):
# Default network architecture, from the original paper
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
net_arch = [256, 256]
else:
net_arch = [400, 300]
@ -281,9 +281,9 @@ class CnnPolicy(TD3Policy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super(CnnPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -335,9 +335,9 @@ class MultiInputPolicy(TD3Policy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = True,
share_features_extractor: bool = False,
):
super(MultiInputPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,
@ -351,8 +351,3 @@ class MultiInputPolicy(TD3Policy):
n_critics,
share_features_extractor,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)
register_policy("MultiInputPolicy", MultiInputPolicy)

View file

@ -8,9 +8,10 @@ from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
class TD3(OffPolicyAlgorithm):
@ -60,6 +61,12 @@ class TD3(OffPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}
def __init__(
self,
policy: Union[str, Type[TD3Policy]],
@ -88,10 +95,9 @@ class TD3(OffPolicyAlgorithm):
_init_setup_model: bool = True,
):
super(TD3, self).__init__(
super().__init__(
policy,
env,
TD3Policy,
learning_rate,
buffer_size,
learning_starts,
@ -123,7 +129,7 @@ class TD3(OffPolicyAlgorithm):
self._setup_model()
def _setup_model(self) -> None:
super(TD3, self)._setup_model()
super()._setup_model()
self._create_aliases()
def _create_aliases(self) -> None:
@ -162,7 +168,7 @@ class TD3(OffPolicyAlgorithm):
current_q_values = self.critic(replay_data.observations, replay_data.actions)
# Compute critic loss
critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
critic_losses.append(critic_loss.item())
# Optimize the critics
@ -202,7 +208,7 @@ class TD3(OffPolicyAlgorithm):
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super(TD3, self).learn(
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
@ -215,7 +221,7 @@ class TD3(OffPolicyAlgorithm):
)
def _excluded_save_params(self) -> List[str]:
return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]

View file

@ -1 +1 @@
1.5.1a0
1.6.1a0

View file

@ -4,9 +4,10 @@ import pytest
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
@ -71,7 +72,7 @@ def test_replay_buffer_normalization(replay_buffer_cls):
env = make_vec_env(env)
env = VecNormalize(env)
buffer = replay_buffer_cls(100, env.observation_space, env.action_space)
buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")
# Interract and store transitions
env.reset()
@ -94,3 +95,47 @@ def test_replay_buffer_normalization(replay_buffer_cls):
assert th.allclose(observations.mean(0), th.zeros(1), atol=1)
# Test reward normalization
assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_device_buffer(replay_buffer_cls, device):
if device == "cuda" and not th.cuda.is_available():
pytest.skip("CUDA not available")
env = {
RolloutBuffer: DummyEnv,
DictRolloutBuffer: DummyDictEnv,
ReplayBuffer: DummyEnv,
DictReplayBuffer: DummyDictEnv,
}[replay_buffer_cls]
env = make_vec_env(env)
buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)
# Interract and store transitions
obs = env.reset()
for _ in range(100):
action = env.action_space.sample()
next_obs, reward, done, info = env.step(action)
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
buffer.add(obs, action, reward, episode_start, values, log_prob)
else:
buffer.add(obs, next_obs, action, reward, done, info)
obs = next_obs
# Get data from the buffer
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
data = buffer.get(50)
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
data = buffer.sample(50)
# Check that all data are on the desired device
desired_device = get_device(device).type
for value in list(data):
if isinstance(value, dict):
for key in value.keys():
assert value[key].device.type == desired_device
elif isinstance(value, th.Tensor):
assert value.device.type == desired_device

View file

@ -163,7 +163,9 @@ def test_categorical(dist, CAT_ACTIONS):
BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))),
MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution(
th.rand(1, sum([N_ACTIONS, N_ACTIONS]))
),
SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
StateDependentNoiseDistribution(N_ACTIONS).proba_distribution(
th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])

View file

@ -141,6 +141,8 @@ def test_non_default_spaces(new_obs_space):
spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32),
# Same boundaries
spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
# Unbounded action space
spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
# Almost good, except for one dim
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
],
@ -156,8 +158,14 @@ def test_non_default_action_spaces(new_action_space):
# Change the action space
env.action_space = new_action_space
with pytest.warns(UserWarning):
check_env(env)
# Unbounded action space throws an error,
# the rest only warning
if not np.all(np.isfinite(env.action_space.low)):
with pytest.raises(AssertionError), pytest.warns(UserWarning):
check_env(env)
else:
with pytest.warns(UserWarning):
check_env(env)
def check_reset_assert_error(env, new_reset_return):

View file

@ -10,7 +10,7 @@ from stable_baselines3.common.policies import ActorCriticPolicy
class CustomEnv(gym.Env):
def __init__(self, max_steps=8):
super(CustomEnv, self).__init__()
super().__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.max_steps = max_steps
@ -54,7 +54,7 @@ class InfiniteHorizonEnv(gym.Env):
class CheckGAECallback(BaseCallback):
def __init__(self):
super(CheckGAECallback, self).__init__(verbose=0)
super().__init__(verbose=0)
def _on_rollout_end(self):
buffer = self.model.rollout_buffer
@ -99,7 +99,7 @@ class CustomPolicy(ActorCriticPolicy):
"""Custom Policy with a constant value function"""
def __init__(self, *args, **kwargs):
super(CustomPolicy, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.constant_value = 0.0
def forward(self, obs, deterministic=False):

View file

@ -156,7 +156,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling):
params = deepcopy(model.policy.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values
model.policy.load_state_dict(random_params)

View file

@ -1,6 +1,7 @@
import os
import time
from typing import Sequence
from unittest import mock
import gym
import numpy as np
@ -381,3 +382,16 @@ def test_fps_logger(tmp_path, algo):
# third time, FPS should be the same
model.learn(100, log_interval=1, reset_num_timesteps=False)
assert max_fps / 10 <= logger.name_to_value["time/fps"] <= max_fps
@pytest.mark.parametrize("algo", [A2C, DQN])
def test_fps_no_div_zero(algo):
"""Set time to constant and train algorithm to check no division by zero error.
Time can appear to be constant during short runs on platforms with low-precision
timers. We should avoid division by zero errors e.g. when computing FPS in
this situation."""
with mock.patch("time.time", lambda: 42.0):
with mock.patch("time.time_ns", lambda: 42.0):
model = algo("MlpPolicy", "CartPole-v1")
model.learn(total_timesteps=100)

View file

@ -14,7 +14,7 @@ def test_monitor(tmp_path):
"""
env = gym.make("CartPole-v1")
env.seed(0)
monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
monitor_env = Monitor(env, monitor_file)
monitor_env.reset()
total_steps = 1000
@ -37,7 +37,7 @@ def test_monitor(tmp_path):
assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards)
_ = monitor_env.get_episode_times()
with open(monitor_file, "rt") as file_handler:
with open(monitor_file) as file_handler:
first_line = file_handler.readline()
assert first_line.startswith("#")
metadata = json.loads(first_line[1:])
@ -56,7 +56,7 @@ def test_monitor_load_results(tmp_path):
tmp_path = str(tmp_path)
env1 = gym.make("CartPole-v1")
env1.seed(0)
monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
monitor_env1 = Monitor(env1, monitor_file1)
monitor_files = get_monitor_files(tmp_path)
@ -76,7 +76,7 @@ def test_monitor_load_results(tmp_path):
env2 = gym.make("CartPole-v1")
env2.seed(0)
monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
monitor_env2 = Monitor(env2, monitor_file2)
monitor_files = get_monitor_files(tmp_path)
assert len(monitor_files) == 2

View file

@ -73,11 +73,13 @@ def test_predict(model_class, env_id, device):
obs = env.reset()
action, _ = model.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == env.action_space.shape
assert env.action_space.contains(action)
vec_env_obs = vec_env.reset()
action, _ = model.predict(vec_env_obs)
assert isinstance(action, np.ndarray)
assert action.shape[0] == vec_env_obs.shape[0]
# Special case for DQN to check the epsilon greedy exploration

View file

@ -10,7 +10,10 @@ normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
@pytest.mark.parametrize("model_class", [TD3, DDPG])
@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))])
@pytest.mark.parametrize(
"action_noise",
[normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))],
)
def test_deterministic_pg(model_class, action_noise):
"""
Test for DDPG and variants (TD3).

View file

@ -64,7 +64,7 @@ def test_save_load(tmp_path, model_class):
model.set_parameters(invalid_object_params, exact_match=False)
# Test that exact_match catches when something was missed.
missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
missing_object_params = {k: v for k, v in list(original_params.items())[:-1]}
with pytest.raises(ValueError):
model.set_parameters(missing_object_params, exact_match=True)
@ -375,6 +375,9 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
select_env(model_class),
buffer_size=100,
optimize_memory_usage=optimize_memory_usage,
# we cannot use optimize_memory_usage and handle_timeout_termination
# at the same time
replay_buffer_kwargs={"handle_timeout_termination": not optimize_memory_usage},
policy_kwargs=dict(net_arch=[64]),
learning_starts=10,
)
@ -446,7 +449,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
params = deepcopy(policy.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values
policy.load_state_dict(random_params)
@ -537,7 +540,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
params = deepcopy(q_net.state_dict())
# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
# Update model parameters with the new random values
q_net.load_state_dict(random_params)

View file

@ -9,7 +9,7 @@ from stable_baselines3.common.evaluation import evaluate_policy
class DummyMultiDiscreteSpace(gym.Env):
def __init__(self, nvec):
super(DummyMultiDiscreteSpace, self).__init__()
super().__init__()
self.observation_space = gym.spaces.MultiDiscrete(nvec)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
@ -22,7 +22,7 @@ class DummyMultiDiscreteSpace(gym.Env):
class DummyMultiBinary(gym.Env):
def __init__(self, n):
super(DummyMultiBinary, self).__init__()
super().__init__()
self.observation_space = gym.spaces.MultiBinary(n)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
@ -33,6 +33,19 @@ class DummyMultiBinary(gym.Env):
return self.observation_space.sample(), 0.0, False, {}
class DummyMultidimensionalAction(gym.Env):
def __init__(self):
super().__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
def step(self, action):
return self.observation_space.sample(), 0.0, False, {}
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
def test_identity_spaces(model_class, env):
@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env):
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
def test_action_spaces(model_class, env):
kwargs = {}
if model_class in [SAC, DDPG, TD3]:
supported_action_space = env == "Pendulum-v1"
supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
kwargs["learning_starts"] = 2
kwargs["train_freq"] = 32
elif model_class == DQN:
supported_action_space = env == "CartPole-v1"
elif model_class in [A2C, PPO]:
supported_action_space = True
kwargs["n_steps"] = 64
if supported_action_space:
model_class("MlpPolicy", env)
model = model_class("MlpPolicy", env, **kwargs)
if isinstance(env, DummyMultidimensionalAction):
model.learn(64)
else:
with pytest.raises(AssertionError):
model_class("MlpPolicy", env)
def test_sde_multi_dim():
SAC(
"MlpPolicy",
DummyMultidimensionalAction(),
learning_starts=10,
use_sde=True,
sde_sample_freq=2,
use_sde_at_warmup=True,
).learn(20)
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", ["Taxi-v3"])
def test_discrete_obs_space(model_class, env):

View file

@ -3,6 +3,7 @@ import os
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.utils import get_latest_run_id
MODEL_DICT = {
"a2c": (A2C, "CartPole-v1"),
@ -35,3 +36,13 @@ def test_tensorboard(tmp_path, model_name):
assert os.path.isdir(tmp_path / str(logname + "_1"))
# Check that the log dir name increments correctly
assert os.path.isdir(tmp_path / str(logname + "_2"))
def test_escape_log_name(tmp_path):
# Log name that must be escaped
log_name = "filename[16, 16]"
# Create folder
os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True)
os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True)
last_run_id = get_latest_run_id(tmp_path, log_name)
assert last_run_id == 2

View file

@ -28,7 +28,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
"""
def __init__(self, observation_space: gym.Space):
super(FlattenBatchNormDropoutExtractor, self).__init__(
super().__init__(
observation_space,
get_flattened_obs_dim(observation_space),
)

View file

@ -180,7 +180,7 @@ class AlwaysDoneWrapper(gym.Wrapper):
# Pretends that environment only has single step for each
# episode.
def __init__(self, env):
super(AlwaysDoneWrapper, self).__init__(env)
super().__init__(env)
self.last_obs = None
self.needs_reset = True

View file

@ -12,7 +12,7 @@ class NanAndInfEnv(gym.Env):
metadata = {"render.modes": ["human"]}
def __init__(self):
super(NanAndInfEnv, self).__init__()
super().__init__()
self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)

View file

@ -31,7 +31,7 @@ class CustomGymEnv(gym.Env):
return self.state
def step(self, action):
reward = 1
reward = float(np.random.rand())
self._choose_next_state()
self.current_step += 1
done = self.current_step >= self.ep_length
@ -45,7 +45,9 @@ class CustomGymEnv(gym.Env):
return np.zeros((4, 4, 3))
def seed(self, seed=None):
pass
if seed is not None:
np.random.seed(seed)
self.observation_space.seed(seed)
@staticmethod
def custom_method(dim_0=1, dim_1=1):
@ -440,3 +442,34 @@ def test_vec_env_is_wrapped():
vec_env = VecFrameStack(vec_env, n_stack=2)
assert vec_env.env_is_wrapped(Monitor) == [False, True]
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vec_seeding(vec_env_class):
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
# For SubprocVecEnv check for all starting methods
start_methods = [None]
if vec_env_class != DummyVecEnv:
all_methods = {"forkserver", "spawn", "fork"}
available_methods = multiprocessing.get_all_start_methods()
start_methods = list(all_methods.intersection(available_methods))
for start_method in start_methods:
if start_method is not None:
vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)
n_envs = 3
vec_env = vec_env_class([make_env] * n_envs)
# Seed with no argument
vec_env.seed()
obs = vec_env.reset()
_, rewards, _, _ = vec_env.step(np.array([vec_env.action_space.sample() for _ in range(n_envs)]))
# Seed should be different per process
assert not np.allclose(obs[0], obs[1])
assert not np.allclose(rewards[0], rewards[1])
assert not np.allclose(obs[1], obs[2])
assert not np.allclose(rewards[1], rewards[2])
vec_env.close()

View file

@ -36,7 +36,7 @@ def test_vec_monitor(tmp_path):
monitor_env.close()
with open(monitor_file, "rt") as file_handler:
with open(monitor_file) as file_handler:
first_line = file_handler.readline()
assert first_line.startswith("#")
metadata = json.loads(first_line[1:])
@ -66,7 +66,7 @@ def test_vec_monitor_info_keywords(tmp_path):
monitor_env.close()
with open(monitor_file, "rt") as f:
with open(monitor_file) as f:
reader = csv.reader(f)
for i, line in enumerate(reader):
if i == 0 or i == 1:

View file

@ -47,7 +47,7 @@ class DummyDictEnv(gym.GoalEnv):
"""
def __init__(self):
super(DummyDictEnv, self).__init__()
super().__init__()
self.observation_space = spaces.Dict(
{
"observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
@ -388,11 +388,11 @@ def test_offpolicy_normalization(model_class, online_sampling):
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_sync_vec_normalize(make_env):
env = DummyVecEnv([make_env])
original_env = DummyVecEnv([make_env])
assert unwrap_vec_normalize(env) is None
assert unwrap_vec_normalize(original_env) is None
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
@ -433,6 +433,17 @@ def test_sync_vec_normalize(make_env):
assert allclose(obs, eval_env.normalize_obs(original_obs))
assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
# Check synchronization when only reward is normalized
env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0)
eval_env = DummyVecEnv([make_env])
eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False)
env.reset()
env.step([env.action_space.sample()])
assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
sync_envs_normalization(env, eval_env)
assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var)
def test_discrete_obs():
with pytest.raises(ValueError, match=".*only supports.*"):