mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-20 02:08:01 +00:00
Merge branch 'master' into feat/redq
This commit is contained in:
commit
ebf6ed1d0a
77 changed files with 724 additions and 398 deletions
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
|
|
@ -28,7 +28,7 @@ jobs:
|
|||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# cpu version of pytorch
|
||||
pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
pip install .[extra,tests,docs]
|
||||
# Use headless version
|
||||
pip install opencv-python-headless
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ pytest:
|
|||
- python --version
|
||||
# MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error
|
||||
- MKL_THREADING_LAYER=GNU make pytest
|
||||
coverage: '/^TOTAL.+?(\d+\%)$/'
|
||||
|
||||
doc-build:
|
||||
script:
|
||||
|
|
|
|||
13
README.md
13
README.md
|
|
@ -51,7 +51,7 @@ Documentation is available online: [https://stable-baselines3.readthedocs.io/](h
|
|||
|
||||
## Integrations
|
||||
|
||||
Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation.
|
||||
Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation.
|
||||
|
||||
|
||||
## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents
|
||||
|
|
@ -77,14 +77,14 @@ Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.h
|
|||
|
||||
We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)
|
||||
|
||||
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
|
||||
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
|
||||
|
||||
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
**Note:** Stable-Baselines3 supports PyTorch >= 1.8.1.
|
||||
**Note:** Stable-Baselines3 supports PyTorch >= 1.11
|
||||
|
||||
### Prerequisites
|
||||
Stable Baselines3 requires Python 3.7+.
|
||||
|
|
@ -122,7 +122,7 @@ from stable_baselines3 import PPO
|
|||
env = gym.make("CartPole-v1")
|
||||
|
||||
model = PPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10000)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
obs = env.reset()
|
||||
for i in range(1000):
|
||||
|
|
@ -140,7 +140,7 @@ Or just train a model with a one liner if [the environment is registered in Gym]
|
|||
```python
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
|
||||
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
|
||||
```
|
||||
|
||||
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples.
|
||||
|
|
@ -172,6 +172,7 @@ All the following examples can be executed online using Google colab notebooks:
|
|||
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: |
|
||||
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| QR-DQN<sup>[1](#f1)</sup> | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
|
||||
| RecurrentPPO<sup>[1](#f1)</sup> | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
|
||||
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
|
||||
| TQC<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
|
||||
|
|
@ -231,7 +232,7 @@ To cite this repository in publications:
|
|||
|
||||
## Maintainers
|
||||
|
||||
Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave) and [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli).
|
||||
Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
|
||||
|
||||
**Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email.
|
||||
Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case.
|
||||
|
|
|
|||
BIN
docs/_static/img/split_graph.png
vendored
Normal file
BIN
docs/_static/img/split_graph.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
|
|
@ -6,9 +6,9 @@ dependencies:
|
|||
- cpuonly=1.0=0
|
||||
- pip=21.1
|
||||
- python=3.7
|
||||
- pytorch=1.8.1=py3.7_cpu_0
|
||||
- pytorch=1.11=py3.7_cpu_0
|
||||
- pip:
|
||||
- gym>=0.17.2
|
||||
- gym==0.21
|
||||
- cloudpickle
|
||||
- opencv-python-headless
|
||||
- pandas
|
||||
|
|
@ -16,5 +16,5 @@ dependencies:
|
|||
- matplotlib
|
||||
- sphinx_autodoc_typehints
|
||||
- sphinx>=4.2
|
||||
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
|
||||
- sphinx_rtd_theme>=1.0
|
||||
- sphinx_copybutton
|
||||
|
|
|
|||
18
docs/conf.py
18
docs/conf.py
|
|
@ -1,4 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
|
|
@ -25,6 +24,14 @@ try:
|
|||
except ImportError:
|
||||
enable_spell_check = False
|
||||
|
||||
# Try to enable copy button
|
||||
try:
|
||||
import sphinx_copybutton # noqa: F401
|
||||
|
||||
enable_copy_button = True
|
||||
except ImportError:
|
||||
enable_copy_button = False
|
||||
|
||||
# source code directory, relative to this file, for sphinx-autobuild
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
|
|
@ -46,13 +53,13 @@ sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
|
|||
|
||||
# Read version from file
|
||||
version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt")
|
||||
with open(version_file, "r") as file_handler:
|
||||
with open(version_file) as file_handler:
|
||||
__version__ = file_handler.read().strip()
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "Stable Baselines3"
|
||||
copyright = "2020, Stable Baselines3"
|
||||
copyright = "2022, Stable Baselines3"
|
||||
author = "Stable Baselines3 Contributors"
|
||||
|
||||
# The short X.Y version
|
||||
|
|
@ -84,6 +91,9 @@ extensions = [
|
|||
if enable_spell_check:
|
||||
extensions.append("sphinxcontrib.spelling")
|
||||
|
||||
if enable_copy_button:
|
||||
extensions.append("sphinx_copybutton")
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ["_templates"]
|
||||
|
||||
|
|
@ -101,7 +111,7 @@ master_doc = "index"
|
|||
#
|
||||
# This is also used if you do content translation via gettext catalogs.
|
||||
# Usually you set "language" from the command line for these cases.
|
||||
language = None
|
||||
language = "en"
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ DQN ❌ ✔️ ❌ ❌
|
|||
HER ✔️ ✔️ ❌ ❌ ❌
|
||||
PPO ✔️ ✔️ ✔️ ✔️ ✔️
|
||||
QR-DQN [#f1]_ ❌ ️ ✔️ ❌ ❌ ✔️
|
||||
RecurrentPPO [#f1]_ ✔️ ✔️ ✔️ ✔️ ✔️
|
||||
SAC ✔️ ❌ ❌ ❌ ✔️
|
||||
TD3 ✔️ ❌ ❌ ❌ ✔️
|
||||
TQC [#f1]_ ✔️ ❌ ❌ ❌ ✔️
|
||||
|
|
@ -26,8 +27,8 @@ Maskable PPO [#f1]_ ❌ ✔️ ✔️ ✔
|
|||
.. [#f1] Implemented in `SB3 Contrib <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib>`_
|
||||
|
||||
.. note::
|
||||
``Tuple`` observation spaces are not supported by any environment
|
||||
however single-level ``Dict`` spaces are (cf. :ref:`Examples <examples>`).
|
||||
``Tuple`` observation spaces are not supported by any environment,
|
||||
however, single-level ``Dict`` spaces are (cf. :ref:`Examples <examples>`).
|
||||
|
||||
|
||||
Actions ``gym.spaces``:
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ Then you can define and train a RL agent with:
|
|||
model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
|
||||
|
||||
|
||||
To check that your environment follows the gym interface, please use:
|
||||
To check that your environment follows the Gym interface that SB3 supports, please use:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
@ -71,11 +71,11 @@ To check that your environment follows the gym interface, please use:
|
|||
# It will check your custom environment and output additional warnings if needed
|
||||
check_env(env)
|
||||
|
||||
|
||||
Gym also has its own `env checker <https://www.gymlibrary.ml/content/api/#checking-api-conformity>`_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features).
|
||||
|
||||
We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface.
|
||||
|
||||
Alternatively, you may look at OpenAI Gym `built-in environments <https://gym.openai.com/docs/#available-environments>`_. However, the readers are cautioned as per OpenAI Gym `official wiki <https://github.com/openai/gym/wiki/FAQ>`_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them.
|
||||
Alternatively, you may look at OpenAI Gym `built-in environments <https://www.gymlibrary.ml/>`_. However, readers are cautioned that, as per the OpenAI Gym `official wiki <https://github.com/openai/gym/wiki/FAQ>`_, it is advised not to customize the built-in environments. It is better to copy and create new ones if you need to modify them.
|
||||
|
||||
Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env):
|
||||
|
||||
|
|
|
|||
|
|
@ -729,6 +729,16 @@ to keep track of the agent progress.
|
|||
model.learn(10_000)
|
||||
|
||||
|
||||
SB3 with EnvPool or Isaac Gym
|
||||
-----------------------------
|
||||
|
||||
Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_ accelerate the environment by
|
||||
already providing a vectorized implementation.
|
||||
|
||||
To use SB3 with those tools, you must wrap the env with the tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3,
|
||||
you can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002>`_.
|
||||
|
||||
|
||||
Record a Video
|
||||
--------------
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Installation
|
|||
Prerequisites
|
||||
-------------
|
||||
|
||||
Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.8.1.
|
||||
Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.11
|
||||
|
||||
Windows 10
|
||||
~~~~~~~~~~
|
||||
|
|
@ -54,6 +54,17 @@ Bleeding-edge version
|
|||
pip install git+https://github.com/DLR-RM/stable-baselines3
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
If you want to use the latest gym version (0.24+), you have to use
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install git+https://github.com/carlosluis/stable-baselines3@fix_tests
|
||||
|
||||
See `PR #780 <https://github.com/DLR-RM/stable-baselines3/pull/780>`_ for more information.
|
||||
|
||||
|
||||
Development version
|
||||
-------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -48,11 +48,14 @@ Hugging Face 🤗
|
|||
The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.
|
||||
|
||||
You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3
|
||||
Most of them are available via the RL Zoo.
|
||||
|
||||
Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3
|
||||
|
||||
We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T
|
||||
|
||||
For up to date instructions (for instance for using ``package_to_hub()``), please take a look at the Huggingface SB3 package README: https://github.com/huggingface/huggingface_sb3
|
||||
|
||||
Installation
|
||||
-------------
|
||||
|
||||
|
|
@ -137,3 +140,56 @@ Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a
|
|||
filename="ppo-CartPole-v1",
|
||||
commit_message="Added Cartpole-v1 model trained with PPO",
|
||||
)
|
||||
|
||||
MLFLow
|
||||
======
|
||||
|
||||
If you want to use `MLFLow <https://github.com/mlflow/mlflow>`_ to track your SB3 experiments,
|
||||
you can adapt the following code which defines a custom logger output:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import sys
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import mlflow
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger
|
||||
|
||||
|
||||
class MLflowOutputFormat(KVWriter):
|
||||
"""
|
||||
Dumps key/value pairs into MLflow's numeric format.
|
||||
"""
|
||||
|
||||
def write(
|
||||
self,
|
||||
key_values: Dict[str, Any],
|
||||
key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
|
||||
step: int = 0,
|
||||
) -> None:
|
||||
|
||||
for (key, value), (_, excluded) in zip(
|
||||
sorted(key_values.items()), sorted(key_excluded.items())
|
||||
):
|
||||
|
||||
if excluded is not None and "mlflow" in excluded:
|
||||
continue
|
||||
|
||||
if isinstance(value, np.ScalarType):
|
||||
if not isinstance(value, str):
|
||||
mlflow.log_metric(key, value, step)
|
||||
|
||||
|
||||
loggers = Logger(
|
||||
folder=None,
|
||||
output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
|
||||
)
|
||||
|
||||
with mlflow.start_run():
|
||||
model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
|
||||
# Set custom logger
|
||||
model.set_logger(loggers)
|
||||
model.learn(total_timesteps=10000, log_interval=1)
|
||||
|
|
|
|||
|
|
@ -141,7 +141,7 @@ DQN
|
|||
^^^
|
||||
|
||||
Only the vanilla DQN is implemented right now but extensions will follow.
|
||||
Default hyperparameters are taken from the nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
|
||||
Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults.
|
||||
|
||||
DDPG
|
||||
^^^^
|
||||
|
|
|
|||
|
|
@ -183,6 +183,16 @@ Some basic advice:
|
|||
- start with shaped reward (i.e. informative reward) and simplified version of your problem
|
||||
- debug with random actions to check that your environment works and follows the gym interface:
|
||||
|
||||
Two important things to keep in mind when creating a custom environment are to avoid breaking the Markov assumption
|
||||
and properly handle termination due to a timeout (maximum number of steps in an episode).
|
||||
For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give a history of observations
|
||||
as input.
|
||||
|
||||
Termination due to timeout (max number of steps per episode) needs to be handled separately. You should fill the key in the info dict: ``info["TimeLimit.truncated"] = True``.
|
||||
If you are using the gym ``TimeLimit`` wrapper, this will be done automatically.
|
||||
You can read `Time Limit in RL <https://arxiv.org/abs/1712.00378>`_ or take a look at the `RL Tips and Tricks video <https://www.youtube.com/watch?v=Ikngt0_DXJg>`_
|
||||
for more details.
|
||||
|
||||
|
||||
We provide a helper to check that your environment runs without error:
|
||||
|
||||
|
|
@ -241,12 +251,15 @@ We *recommend following those steps to have a working RL algorithm*:
|
|||
1. Read the original paper several times
|
||||
2. Read existing implementations (if available)
|
||||
3. Try to have some "sign of life" on toy problems
|
||||
4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo)
|
||||
You usually need to run hyperparameter optimization for that step.
|
||||
4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo).
|
||||
You usually need to run hyperparameter optimization for that step.
|
||||
|
||||
You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf `issue #75 <https://github.com/hill-a/stable-baselines/pull/76>`_)
|
||||
You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf. `issue #75 <https://github.com/hill-a/stable-baselines/pull/76>`_)
|
||||
and when to stop the gradient propagation.
|
||||
|
||||
Don't forget to handle termination due to timeout separately (see remark in the custom environment section above),
|
||||
you can also take a look at `Issue #284 <https://github.com/DLR-RM/stable-baselines3/issues/284>`_ and `Issue #633 <https://github.com/DLR-RM/stable-baselines3/issues/633>`_.
|
||||
|
||||
A personal pick (by @araffin) for environments with gradual difficulty in RL with continuous actions:
|
||||
|
||||
1. Pendulum (easy to solve)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ We implement experimental features in a separate contrib repository:
|
|||
`SB3-Contrib`_
|
||||
|
||||
This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still
|
||||
providing the latest features, like Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or
|
||||
providing the latest features, like RecurrentPPO (PPO LSTM), Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or
|
||||
Quantile Regression DQN (QR-DQN).
|
||||
|
||||
Why create this repository?
|
||||
|
|
@ -38,9 +38,11 @@ See documentation for the full list of included features.
|
|||
|
||||
- `Augmented Random Search (ARS) <https://arxiv.org/abs/1803.07055>`_
|
||||
- `Quantile Regression DQN (QR-DQN)`_
|
||||
- `PPO with invalid action masking (Maskable PPO) <https://arxiv.org/abs/2006.14171>`_
|
||||
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
|
||||
- `Truncated Quantile Critics (TQC)`_
|
||||
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
|
||||
- `PPO with invalid action masking (Maskable PPO) <https://arxiv.org/abs/2006.14171>`_
|
||||
|
||||
|
||||
**Gym Wrappers**:
|
||||
|
||||
|
|
|
|||
|
|
@ -26,10 +26,20 @@ You can also define custom logging name when training (by default it is the algo
|
|||
model.learn(total_timesteps=10_000, tb_log_name="first_run")
|
||||
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
|
||||
# By default, it will create a new curve
|
||||
# Keep tb_log_name constant to have continuous curve (see note below)
|
||||
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
|
||||
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
|
||||
|
||||
|
||||
.. note::
|
||||
If you specify different ``tb_log_name`` in subsequent runs, you will have split graphs, like in the figure below.
|
||||
If you want them to be continuous, you must keep the same ``tb_log_name`` (see `issue #975 <https://github.com/DLR-RM/stable-baselines3/issues/975#issuecomment-1198992211>`_).
|
||||
And, if you still managed to get your graphs split by other means, just put tensorboard log files into the same folder.
|
||||
|
||||
.. image:: ../_static/img/split_graph.png
|
||||
:width: 330
|
||||
:alt: split_graph
|
||||
|
||||
Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
|
||||
Release 1.5.1a0 (WIP)
|
||||
Release 1.6.1a0 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -18,16 +17,80 @@ SB3-Contrib
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
|
||||
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
||||
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
|
||||
- Added multidimensional action space support (@qgallouedec)
|
||||
- Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec)
|
||||
|
||||
- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec)
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Fixed typo in docstring "nature" -> "Nature" (@Melanol)
|
||||
- Added info on split tensorboard logs (@Melanol)
|
||||
- Fixed typo in ppo doc (@francescoluciano)
|
||||
- Fixed typo in install doc (@jlp-ue)
|
||||
|
||||
|
||||
Release 1.6.0 (2022-07-11)
|
||||
---------------------------
|
||||
|
||||
**Recurrent PPO (PPO LSTM), better defaults for learning from pixels with SAC/TD3**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former
|
||||
``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar)
|
||||
- SB3 now requires PyTorch >= 1.11
|
||||
- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with SAC or DDPG/TD3,
|
||||
``share_features_extractor`` is now set to False by default and ``net_arch`` now defaults to ``[256, 256]`` (instead of the previous ``net_arch=[]``)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
- Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53
|
||||
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517)
|
||||
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
|
||||
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
|
||||
- Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP)
|
||||
- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled
|
||||
- Added a check for unbounded actions
|
||||
- Fixed issues due to newer version of protobuf (tensorboard) and sphinx
|
||||
- Fix exception causes all over the codebase (@cool-RR)
|
||||
- Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede)
|
||||
- Fixed a bug in ``kl_divergence`` check that would fail when using numpy arrays with MultiCategorical distribution
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Upgraded to Python 3.7+ syntax using ``pyupgrade``
|
||||
- Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG)
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Added link to gym doc and gym env checker
|
||||
- Fix typo in PPO doc (@bcollazo)
|
||||
- Added link to PPO ICLR blog post
|
||||
- Added remark about breaking Markov assumption and timeout handling
|
||||
- Added doc about MLFlow integration via custom logger (@git-thor)
|
||||
- Updated Huggingface integration doc
|
||||
- Added copy button for code snippets
|
||||
- Added doc about EnvPool and Isaac Gym support
|
||||
|
||||
|
||||
Release 1.5.0 (2022-03-25)
|
||||
|
|
@ -923,7 +986,8 @@ Maintainers
|
|||
-----------
|
||||
|
||||
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a),
|
||||
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_).
|
||||
`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_), `Anssi Kanervisto`_ (aka `@Miffyli`_)
|
||||
and `Quentin Gallouédec`_ (aka @qgallouedec).
|
||||
|
||||
.. _Ashley Hill: https://github.com/hill-a
|
||||
.. _Antonin Raffin: https://araffin.github.io/
|
||||
|
|
@ -933,6 +997,8 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
|||
.. _@AdamGleave: https://github.com/adamgleave
|
||||
.. _Anssi Kanervisto: https://github.com/Miffyli
|
||||
.. _@Miffyli: https://github.com/Miffyli
|
||||
.. _Quentin Gallouédec: https://gallouedec.com/
|
||||
.. _@qgallouedec: https://github.com/qgallouedec
|
||||
|
||||
|
||||
|
||||
|
|
@ -957,4 +1023,5 @@ And all the contributors:
|
|||
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
|
||||
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
|
||||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@ PPO
|
|||
The `Proximal Policy Optimization <https://arxiv.org/abs/1707.06347>`_ algorithm combines ideas from A2C (having multiple workers)
|
||||
and TRPO (it uses a trust region to improve the actor).
|
||||
|
||||
The main idea is that after an update, the new policy should be not too far form the old policy.
|
||||
The main idea is that after an update, the new policy should not be too far from the old policy.
|
||||
For that, PPO uses clipping to avoid too large an update.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
PPO contains several modifications from the original algorithm not documented
|
||||
by OpenAI: advantages are normalized and value function can be also clipped .
|
||||
by OpenAI: advantages are normalized and value function can be also clipped.
|
||||
|
||||
|
||||
Notes
|
||||
|
|
@ -25,11 +25,22 @@ Notes
|
|||
- Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8
|
||||
- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
|
||||
- Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html
|
||||
- 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/
|
||||
|
||||
|
||||
Can I use?
|
||||
----------
|
||||
|
||||
.. note::
|
||||
|
||||
A recurrent version of PPO is available in our contrib repo: https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html
|
||||
|
||||
However we advise users to start with simple frame-stacking as a simpler, faster
|
||||
and usually competitive alternative, more info in our report: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4
|
||||
See also `Procgen paper appendix Fig 11. <https://arxiv.org/abs/1912.01588>`_.
|
||||
In practice, you can stack multiple observations using ``VecFrameStack``.
|
||||
|
||||
|
||||
- Recurrent policies: ❌
|
||||
- Multi processing: ✔️
|
||||
- Gym spaces:
|
||||
|
|
|
|||
21
setup.py
21
setup.py
|
|
@ -2,7 +2,7 @@ import os
|
|||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler:
|
||||
with open(os.path.join("stable_baselines3", "version.txt")) as file_handler:
|
||||
__version__ = file_handler.read().strip()
|
||||
|
||||
|
||||
|
|
@ -43,10 +43,10 @@ import gym
|
|||
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
env = gym.make('CartPole-v1')
|
||||
env = gym.make("CartPole-v1")
|
||||
|
||||
model = PPO('MlpPolicy', env, verbose=1)
|
||||
model.learn(total_timesteps=10000)
|
||||
model = PPO("MlpPolicy", env, verbose=1)
|
||||
model.learn(total_timesteps=10_000)
|
||||
|
||||
obs = env.reset()
|
||||
for i in range(1000):
|
||||
|
|
@ -57,12 +57,12 @@ for i in range(1000):
|
|||
obs = env.reset()
|
||||
```
|
||||
|
||||
Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
|
||||
Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html):
|
||||
|
||||
```python
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
model = PPO('MlpPolicy', 'CartPole-v1').learn(10000)
|
||||
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
|
||||
```
|
||||
|
||||
""" # noqa:E501
|
||||
|
|
@ -75,7 +75,7 @@ setup(
|
|||
install_requires=[
|
||||
"gym==0.21", # Fixed version due to breaking changes in 0.22
|
||||
"numpy",
|
||||
"torch>=1.8.1",
|
||||
"torch>=1.11",
|
||||
# For saving models
|
||||
"cloudpickle",
|
||||
# For reading logs
|
||||
|
|
@ -111,16 +111,21 @@ setup(
|
|||
"sphinxcontrib.spelling",
|
||||
# Type hints support
|
||||
"sphinx-autodoc-typehints",
|
||||
# Copy button for code snippets
|
||||
"sphinx_copybutton",
|
||||
],
|
||||
"extra": [
|
||||
# For render
|
||||
"opencv-python",
|
||||
# For atari games,
|
||||
"ale-py~=0.7.4",
|
||||
"ale-py==0.7.4",
|
||||
"autorom[accept-rom-license]~=0.4.2",
|
||||
"pillow",
|
||||
# Tensorboard support
|
||||
"tensorboard>=2.2.0",
|
||||
# Protobuf >= 4 has breaking changes
|
||||
# which does not play well with tensorboard
|
||||
"protobuf~=3.19.0",
|
||||
# Checking memory taken by replay buffer
|
||||
"psutil",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from stable_baselines3.td3 import TD3
|
|||
|
||||
# Read version from file
|
||||
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
|
||||
with open(version_file, "r") as file_handler:
|
||||
with open(version_file) as file_handler:
|
||||
__version__ = file_handler.read().strip()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from gym import spaces
|
|||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance
|
||||
|
||||
|
|
@ -51,6 +51,12 @@ class A2C(OnPolicyAlgorithm):
|
|||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||
"MlpPolicy": ActorCriticPolicy,
|
||||
"CnnPolicy": ActorCriticCnnPolicy,
|
||||
"MultiInputPolicy": MultiInputActorCriticPolicy,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[ActorCriticPolicy]],
|
||||
|
|
@ -76,7 +82,7 @@ class A2C(OnPolicyAlgorithm):
|
|||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(A2C, self).__init__(
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
learning_rate=learning_rate,
|
||||
|
|
@ -188,7 +194,7 @@ class A2C(OnPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
) -> "A2C":
|
||||
|
||||
return super(A2C, self).learn(
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,7 @@
|
|||
# This file is here just to define MlpPolicy/CnnPolicy
|
||||
# that work for A2C
|
||||
from stable_baselines3.common.policies import (
|
||||
ActorCriticCnnPolicy,
|
||||
ActorCriticPolicy,
|
||||
MultiInputActorCriticPolicy,
|
||||
register_policy,
|
||||
)
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||
|
||||
MlpPolicy = ActorCriticPolicy
|
||||
CnnPolicy = ActorCriticCnnPolicy
|
||||
MultiInputPolicy = MultiInputActorCriticPolicy
|
||||
|
||||
register_policy("MlpPolicy", ActorCriticPolicy)
|
||||
register_policy("CnnPolicy", ActorCriticCnnPolicy)
|
||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||
|
|
|
|||
|
|
@ -245,4 +245,4 @@ class AtariWrapper(gym.Wrapper):
|
|||
if clip_reward:
|
||||
env = ClipRewardEnv(env)
|
||||
|
||||
super(AtariWrapper, self).__init__(env)
|
||||
super().__init__(env)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from stable_baselines3.common.env_util import is_wrapped
|
|||
from stable_baselines3.common.logger import Logger
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
|
|
@ -60,7 +60,6 @@ class BaseAlgorithm(ABC):
|
|||
:param policy: Policy object
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param learning_rate: learning rate for the optimizer,
|
||||
it can be a function of the current progress remaining (from 1 to 0)
|
||||
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
||||
|
|
@ -83,11 +82,13 @@ class BaseAlgorithm(ABC):
|
|||
:param supported_action_spaces: The action spaces supported by the algorithm.
|
||||
"""
|
||||
|
||||
# Policy aliases (see _get_policy_from_name())
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str, None],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Schedule],
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
|
|
@ -101,9 +102,8 @@ class BaseAlgorithm(ABC):
|
|||
sde_sample_freq: int = -1,
|
||||
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
|
||||
):
|
||||
|
||||
if isinstance(policy, str) and policy_base is not None:
|
||||
self.policy_class = get_policy_from_name(policy_base, policy)
|
||||
if isinstance(policy, str):
|
||||
self.policy_class = self._get_policy_from_name(policy)
|
||||
else:
|
||||
self.policy_class = policy
|
||||
|
||||
|
|
@ -185,6 +185,11 @@ class BaseAlgorithm(ABC):
|
|||
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
|
||||
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
assert np.all(
|
||||
np.isfinite(np.array([self.action_space.low, self.action_space.high]))
|
||||
), "Continuous action space must have a finite lower and upper bound"
|
||||
|
||||
@staticmethod
|
||||
def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
|
||||
""" "
|
||||
|
|
@ -209,11 +214,6 @@ class BaseAlgorithm(ABC):
|
|||
# Make sure that dict-spaces are not nested (not supported)
|
||||
check_for_nested_spaces(env.observation_space)
|
||||
|
||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
||||
for space in env.observation_space.spaces.values():
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
raise ValueError("Nested observation spaces are not supported (Dict spaces inside Dict space).")
|
||||
|
||||
if not is_vecenv_wrapped(env, VecTransposeImage):
|
||||
wrap_with_vectranspose = False
|
||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
||||
|
|
@ -325,6 +325,23 @@ class BaseAlgorithm(ABC):
|
|||
"_custom_logger",
|
||||
]
|
||||
|
||||
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
|
||||
"""
|
||||
Get a policy class from its name representation.
|
||||
|
||||
The goal here is to standardize policy naming, e.g.
|
||||
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
|
||||
and they receive respective policies that work for them.
|
||||
|
||||
:param policy_name: Alias of the policy
|
||||
:return: A policy class (type)
|
||||
"""
|
||||
|
||||
if policy_name in self.policy_aliases:
|
||||
return self.policy_aliases[policy_name]
|
||||
else:
|
||||
raise ValueError(f"Policy {policy_name} unknown")
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Get the name of the torch variables that will be saved with
|
||||
|
|
@ -375,6 +392,7 @@ class BaseAlgorithm(ABC):
|
|||
log_path=log_path,
|
||||
eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
callback = CallbackList([callback, eval_callback])
|
||||
|
||||
|
|
@ -405,7 +423,7 @@ class BaseAlgorithm(ABC):
|
|||
:param tb_log_name: the name of the run for tensorboard log
|
||||
:return:
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
self.start_time = time.time_ns()
|
||||
|
||||
if self.ep_info_buffer is None or reset_num_timesteps:
|
||||
# Initialize buffers if they don't exist, or reinitialize if resetting counters
|
||||
|
|
@ -611,11 +629,11 @@ class BaseAlgorithm(ABC):
|
|||
attr = None
|
||||
try:
|
||||
attr = recursive_getattr(self, name)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# What errors could recursive_getattr throw? KeyError, but
|
||||
# possibly something else too (e.g. if key is an int?).
|
||||
# Catch anything for now.
|
||||
raise ValueError(f"Key {name} is an invalid object name.")
|
||||
raise ValueError(f"Key {name} is an invalid object name.") from e
|
||||
|
||||
if isinstance(attr, th.optim.Optimizer):
|
||||
# Optimizers do not support "strict" keyword...
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from stable_baselines3.common.type_aliases import (
|
|||
ReplayBufferSamples,
|
||||
RolloutBufferSamples,
|
||||
)
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
try:
|
||||
|
|
@ -39,10 +40,10 @@ class BaseBuffer(ABC):
|
|||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
n_envs: int = 1,
|
||||
):
|
||||
super(BaseBuffer, self).__init__()
|
||||
super().__init__()
|
||||
self.buffer_size = buffer_size
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
|
@ -51,7 +52,7 @@ class BaseBuffer(ABC):
|
|||
self.action_dim = get_action_dim(action_space)
|
||||
self.pos = 0
|
||||
self.full = False
|
||||
self.device = device
|
||||
self.device = get_device(device)
|
||||
self.n_envs = n_envs
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -157,13 +158,14 @@ class ReplayBuffer(BaseBuffer):
|
|||
:param buffer_size: Max number of elements in the buffer
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param device:
|
||||
:param device: PyTorch device
|
||||
:param n_envs: Number of parallel environments
|
||||
:param optimize_memory_usage: Enable a memory efficient variant
|
||||
of the replay buffer which reduces by almost a factor two the memory used,
|
||||
at a cost of more complexity.
|
||||
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
|
||||
and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
|
||||
Cannot be used in combination with handle_timeout_termination.
|
||||
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
|
||||
separately and treat the task as infinite horizon task.
|
||||
https://github.com/DLR-RM/stable-baselines3/issues/284
|
||||
|
|
@ -174,12 +176,12 @@ class ReplayBuffer(BaseBuffer):
|
|||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
n_envs: int = 1,
|
||||
optimize_memory_usage: bool = False,
|
||||
handle_timeout_termination: bool = True,
|
||||
):
|
||||
super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
|
||||
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
|
||||
|
||||
# Adjust buffer size
|
||||
self.buffer_size = max(buffer_size // n_envs, 1)
|
||||
|
|
@ -188,6 +190,13 @@ class ReplayBuffer(BaseBuffer):
|
|||
if psutil is not None:
|
||||
mem_available = psutil.virtual_memory().available
|
||||
|
||||
# there is a bug if both optimize_memory_usage and handle_timeout_termination are true
|
||||
# see https://github.com/DLR-RM/stable-baselines3/issues/934
|
||||
if optimize_memory_usage and handle_timeout_termination:
|
||||
raise ValueError(
|
||||
"ReplayBuffer does not support optimize_memory_usage = True "
|
||||
"and handle_timeout_termination = True simultaneously."
|
||||
)
|
||||
self.optimize_memory_usage = optimize_memory_usage
|
||||
|
||||
self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
|
||||
|
|
@ -239,8 +248,7 @@ class ReplayBuffer(BaseBuffer):
|
|||
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
|
||||
|
||||
# Same, for actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
# Copy to avoid modification by reference
|
||||
self.observations[self.pos] = np.array(obs).copy()
|
||||
|
|
@ -321,7 +329,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
:param buffer_size: Max number of elements in the buffer
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param device:
|
||||
:param device: PyTorch device
|
||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
Equivalent to classic advantage when set to 1.
|
||||
:param gamma: Discount factor
|
||||
|
|
@ -333,13 +341,13 @@ class RolloutBuffer(BaseBuffer):
|
|||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
):
|
||||
|
||||
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
|
||||
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
|
||||
self.gae_lambda = gae_lambda
|
||||
self.gamma = gamma
|
||||
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
|
||||
|
|
@ -358,7 +366,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
||||
self.generator_ready = False
|
||||
super(RolloutBuffer, self).reset()
|
||||
super().reset()
|
||||
|
||||
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
|
||||
"""
|
||||
|
|
@ -425,6 +433,9 @@ class RolloutBuffer(BaseBuffer):
|
|||
if isinstance(self.observation_space, spaces.Discrete):
|
||||
obs = obs.reshape((self.n_envs,) + self.obs_shape)
|
||||
|
||||
# Same reshape, for actions
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
self.observations[self.pos] = np.array(obs).copy()
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
|
|
@ -483,7 +494,7 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
:param buffer_size: Max number of elements in the buffer
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param device:
|
||||
:param device: PyTorch device
|
||||
:param n_envs: Number of parallel environments
|
||||
:param optimize_memory_usage: Enable a memory efficient variant
|
||||
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
|
||||
|
|
@ -497,7 +508,7 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
n_envs: int = 1,
|
||||
optimize_memory_usage: bool = False,
|
||||
handle_timeout_termination: bool = True,
|
||||
|
|
@ -578,8 +589,7 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
|
||||
|
||||
# Same reshape, for actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
action = action.reshape((self.n_envs, self.action_dim))
|
||||
|
||||
self.actions[self.pos] = np.array(action).copy()
|
||||
self.rewards[self.pos] = np.array(reward).copy()
|
||||
|
|
@ -649,7 +659,7 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
:param buffer_size: Max number of elements in the buffer
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
:param device:
|
||||
:param device: PyTorch device
|
||||
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
Equivalent to Monte-Carlo advantage estimate when set to 1.
|
||||
:param gamma: Discount factor
|
||||
|
|
@ -661,7 +671,7 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
buffer_size: int,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
gae_lambda: float = 1,
|
||||
gamma: float = 0.99,
|
||||
n_envs: int = 1,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class BaseCallback(ABC):
|
|||
"""
|
||||
|
||||
def __init__(self, verbose: int = 0):
|
||||
super(BaseCallback, self).__init__()
|
||||
super().__init__()
|
||||
# The RL model
|
||||
self.model = None # type: Optional[base_class.BaseAlgorithm]
|
||||
# An alias for self.model.get_env(), the environment used for training
|
||||
|
|
@ -127,14 +127,14 @@ class EventCallback(BaseCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
|
||||
super(EventCallback, self).__init__(verbose=verbose)
|
||||
super().__init__(verbose=verbose)
|
||||
self.callback = callback
|
||||
# Give access to the parent
|
||||
if callback is not None:
|
||||
self.callback.parent = self
|
||||
|
||||
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
|
||||
super(EventCallback, self).init_callback(model)
|
||||
super().init_callback(model)
|
||||
if self.callback is not None:
|
||||
self.callback.init_callback(self.model)
|
||||
|
||||
|
|
@ -169,7 +169,7 @@ class CallbackList(BaseCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, callbacks: List[BaseCallback]):
|
||||
super(CallbackList, self).__init__()
|
||||
super().__init__()
|
||||
assert isinstance(callbacks, list)
|
||||
self.callbacks = callbacks
|
||||
|
||||
|
|
@ -228,7 +228,7 @@ class CheckpointCallback(BaseCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
|
||||
super(CheckpointCallback, self).__init__(verbose)
|
||||
super().__init__(verbose)
|
||||
self.save_freq = save_freq
|
||||
self.save_path = save_path
|
||||
self.name_prefix = name_prefix
|
||||
|
|
@ -256,7 +256,7 @@ class ConvertCallback(BaseCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
|
||||
super(ConvertCallback, self).__init__(verbose)
|
||||
super().__init__(verbose)
|
||||
self.callback = callback
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
|
|
@ -307,7 +307,7 @@ class EvalCallback(EventCallback):
|
|||
verbose: int = 1,
|
||||
warn: bool = True,
|
||||
):
|
||||
super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose)
|
||||
super().__init__(callback_after_eval, verbose=verbose)
|
||||
|
||||
self.callback_on_new_best = callback_on_new_best
|
||||
if self.callback_on_new_best is not None:
|
||||
|
|
@ -380,12 +380,12 @@ class EvalCallback(EventCallback):
|
|||
if self.model.get_vec_normalize_env() is not None:
|
||||
try:
|
||||
sync_envs_normalization(self.training_env, self.eval_env)
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
raise AssertionError(
|
||||
"Training and eval env are not wrapped the same way, "
|
||||
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
|
||||
"and warning above."
|
||||
)
|
||||
) from e
|
||||
|
||||
# Reset success rate buffer
|
||||
self._is_success_buffer = []
|
||||
|
|
@ -480,7 +480,7 @@ class StopTrainingOnRewardThreshold(BaseCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, reward_threshold: float, verbose: int = 0):
|
||||
super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose)
|
||||
super().__init__(verbose=verbose)
|
||||
self.reward_threshold = reward_threshold
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
|
|
@ -505,7 +505,7 @@ class EveryNTimesteps(EventCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, n_steps: int, callback: BaseCallback):
|
||||
super(EveryNTimesteps, self).__init__(callback)
|
||||
super().__init__(callback)
|
||||
self.n_steps = n_steps
|
||||
self.last_time_trigger = 0
|
||||
|
||||
|
|
@ -528,7 +528,7 @@ class StopTrainingOnMaxEpisodes(BaseCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, max_episodes: int, verbose: int = 0):
|
||||
super(StopTrainingOnMaxEpisodes, self).__init__(verbose=verbose)
|
||||
super().__init__(verbose=verbose)
|
||||
self.max_episodes = max_episodes
|
||||
self._total_max_episodes = max_episodes
|
||||
self.n_episodes = 0
|
||||
|
|
@ -573,7 +573,7 @@ class StopTrainingOnNoModelImprovement(BaseCallback):
|
|||
"""
|
||||
|
||||
def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
|
||||
super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose)
|
||||
super().__init__(verbose=verbose)
|
||||
self.max_no_improvement_evals = max_no_improvement_evals
|
||||
self.min_evals = min_evals
|
||||
self.last_best_mean_reward = -np.inf
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch import nn
|
||||
|
|
@ -16,7 +17,7 @@ class Distribution(ABC):
|
|||
"""Abstract base class for distributions."""
|
||||
|
||||
def __init__(self):
|
||||
super(Distribution, self).__init__()
|
||||
super().__init__()
|
||||
self.distribution = None
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -120,7 +121,7 @@ class DiagGaussianDistribution(Distribution):
|
|||
"""
|
||||
|
||||
def __init__(self, action_dim: int):
|
||||
super(DiagGaussianDistribution, self).__init__()
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.mean_actions = None
|
||||
self.log_std = None
|
||||
|
|
@ -201,13 +202,13 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
"""
|
||||
|
||||
def __init__(self, action_dim: int, epsilon: float = 1e-6):
|
||||
super(SquashedDiagGaussianDistribution, self).__init__(action_dim)
|
||||
super().__init__(action_dim)
|
||||
# Avoid NaN (prevents division by zero or log of zero)
|
||||
self.epsilon = epsilon
|
||||
self.gaussian_actions = None
|
||||
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
|
||||
super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
|
||||
super().proba_distribution(mean_actions, log_std)
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
|
||||
|
|
@ -219,7 +220,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
gaussian_actions = TanhBijector.inverse(actions)
|
||||
|
||||
# Log likelihood for a Gaussian distribution
|
||||
log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions)
|
||||
log_prob = super().log_prob(gaussian_actions)
|
||||
# Squash correction (from original SAC implementation)
|
||||
# this comes from the fact that tanh is bijective and differentiable
|
||||
log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
|
||||
|
|
@ -254,7 +255,7 @@ class CategoricalDistribution(Distribution):
|
|||
"""
|
||||
|
||||
def __init__(self, action_dim: int):
|
||||
super(CategoricalDistribution, self).__init__()
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
|
||||
|
|
@ -305,7 +306,7 @@ class MultiCategoricalDistribution(Distribution):
|
|||
"""
|
||||
|
||||
def __init__(self, action_dims: List[int]):
|
||||
super(MultiCategoricalDistribution, self).__init__()
|
||||
super().__init__()
|
||||
self.action_dims = action_dims
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
|
||||
|
|
@ -360,7 +361,7 @@ class BernoulliDistribution(Distribution):
|
|||
"""
|
||||
|
||||
def __init__(self, action_dims: int):
|
||||
super(BernoulliDistribution, self).__init__()
|
||||
super().__init__()
|
||||
self.action_dims = action_dims
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
|
||||
|
|
@ -433,7 +434,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
learn_features: bool = False,
|
||||
epsilon: float = 1e-6,
|
||||
):
|
||||
super(StateDependentNoiseDistribution, self).__init__()
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.latent_sde_dim = None
|
||||
self.mean_actions = None
|
||||
|
|
@ -577,10 +578,10 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
return th.mm(latent_sde, self.exploration_mat)
|
||||
# Use batch matrix multiplication for efficient computation
|
||||
# (batch_size, n_features) -> (batch_size, 1, n_features)
|
||||
latent_sde = latent_sde.unsqueeze(1)
|
||||
latent_sde = latent_sde.unsqueeze(dim=1)
|
||||
# (batch_size, 1, n_actions)
|
||||
noise = th.bmm(latent_sde, self.exploration_matrices)
|
||||
return noise.squeeze(1)
|
||||
return noise.squeeze(dim=1)
|
||||
|
||||
def actions_from_params(
|
||||
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
|
||||
|
|
@ -597,7 +598,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
return actions, log_prob
|
||||
|
||||
|
||||
class TanhBijector(object):
|
||||
class TanhBijector:
|
||||
"""
|
||||
Bijective transformation of a probability distribution
|
||||
using a squashing function (tanh)
|
||||
|
|
@ -607,7 +608,7 @@ class TanhBijector(object):
|
|||
"""
|
||||
|
||||
def __init__(self, epsilon: float = 1e-6):
|
||||
super(TanhBijector, self).__init__()
|
||||
super().__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -657,7 +658,6 @@ def make_proba_distribution(
|
|||
dist_kwargs = {}
|
||||
|
||||
if isinstance(action_space, spaces.Box):
|
||||
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
|
||||
return cls(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
|
|
@ -688,7 +688,7 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
|
|||
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
|
||||
# so we need to implement it ourselves!
|
||||
if isinstance(dist_pred, MultiCategoricalDistribution):
|
||||
assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space"
|
||||
assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space"
|
||||
return th.stack(
|
||||
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
|
||||
dim=1,
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
|
|||
try:
|
||||
_check_obs(obs[key], observation_space.spaces[key], "reset")
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"Error while checking key={key}: " + str(e))
|
||||
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
|
||||
else:
|
||||
_check_obs(obs, observation_space, "reset")
|
||||
|
||||
|
|
@ -166,7 +166,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
|
|||
try:
|
||||
_check_obs(obs[key], observation_space.spaces[key], "step")
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"Error while checking key={key}: " + str(e))
|
||||
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
|
||||
|
||||
else:
|
||||
_check_obs(obs, observation_space, "step")
|
||||
|
|
@ -274,6 +274,11 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
|
|||
"cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
|
||||
)
|
||||
|
||||
if isinstance(action_space, spaces.Box):
|
||||
assert np.all(
|
||||
np.isfinite(np.array([action_space.low, action_space.high]))
|
||||
), "Continuous action space must have a finite lower and upper bound"
|
||||
|
||||
if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32):
|
||||
warnings.warn(
|
||||
f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class BitFlippingEnv(GoalEnv):
|
|||
image_obs_space: bool = False,
|
||||
channel_first: bool = True,
|
||||
):
|
||||
super(BitFlippingEnv, self).__init__()
|
||||
super().__init__()
|
||||
# Shape of the observation when using image space
|
||||
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
|
||||
# The achieved goal is determined by the current state
|
||||
|
|
@ -115,7 +115,7 @@ class BitFlippingEnv(GoalEnv):
|
|||
if self.discrete_obs_space:
|
||||
# The internal state is the binary representation of the
|
||||
# observed one
|
||||
return int(sum([state[i] * 2**i for i in range(len(state))]))
|
||||
return int(sum(state[i] * 2**i for i in range(len(state))))
|
||||
|
||||
if self.image_obs_space:
|
||||
size = np.prod(self.image_shape)
|
||||
|
|
@ -135,7 +135,7 @@ class BitFlippingEnv(GoalEnv):
|
|||
if isinstance(state, int):
|
||||
state = np.array(state).reshape(batch_size, -1)
|
||||
# Convert to binary representation
|
||||
state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int)
|
||||
state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
|
||||
elif self.image_obs_space:
|
||||
state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class SimpleMultiObsEnv(gym.Env):
|
|||
discrete_actions: bool = True,
|
||||
channel_last: bool = True,
|
||||
):
|
||||
super(SimpleMultiObsEnv, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.vector_size = 5
|
||||
if channel_last:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ try:
|
|||
except ImportError:
|
||||
SummaryWriter = None
|
||||
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
|
|
@ -24,7 +25,7 @@ ERROR = 40
|
|||
DISABLED = 50
|
||||
|
||||
|
||||
class Video(object):
|
||||
class Video:
|
||||
"""
|
||||
Video data class storing the video frames and the frame per seconds
|
||||
|
||||
|
|
@ -37,7 +38,7 @@ class Video(object):
|
|||
self.fps = fps
|
||||
|
||||
|
||||
class Figure(object):
|
||||
class Figure:
|
||||
"""
|
||||
Figure data class storing a matplotlib figure and whether to close the figure after logging it
|
||||
|
||||
|
|
@ -50,7 +51,7 @@ class Figure(object):
|
|||
self.close = close
|
||||
|
||||
|
||||
class Image(object):
|
||||
class Image:
|
||||
"""
|
||||
Image data class storing an image and data format
|
||||
|
||||
|
|
@ -80,13 +81,13 @@ class FormatUnsupportedError(NotImplementedError):
|
|||
format_str = f"formats {', '.join(unsupported_formats)} are"
|
||||
else:
|
||||
format_str = f"format {unsupported_formats[0]} is"
|
||||
super(FormatUnsupportedError, self).__init__(
|
||||
super().__init__(
|
||||
f"The {format_str} not supported for the {value_description} value logged.\n"
|
||||
f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
|
||||
)
|
||||
|
||||
|
||||
class KVWriter(object):
|
||||
class KVWriter:
|
||||
"""
|
||||
Key Value writer
|
||||
"""
|
||||
|
|
@ -108,7 +109,7 @@ class KVWriter(object):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class SeqWriter(object):
|
||||
class SeqWriter:
|
||||
"""
|
||||
sequence writer
|
||||
"""
|
||||
|
|
@ -246,12 +247,13 @@ def filter_excluded_keys(
|
|||
|
||||
|
||||
class JSONOutputFormat(KVWriter):
|
||||
def __init__(self, filename: str):
|
||||
"""
|
||||
log to a file, in the JSON format
|
||||
"""
|
||||
Log to a file, in the JSON format
|
||||
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str):
|
||||
self.file = open(filename, "wt")
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
|
|
@ -287,13 +289,13 @@ class JSONOutputFormat(KVWriter):
|
|||
|
||||
|
||||
class CSVOutputFormat(KVWriter):
|
||||
"""
|
||||
Log to a file, in a CSV format
|
||||
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str):
|
||||
"""
|
||||
log to a file, in a CSV format
|
||||
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
|
||||
self.file = open(filename, "w+t")
|
||||
self.keys = []
|
||||
self.separator = ","
|
||||
|
|
@ -351,12 +353,13 @@ class CSVOutputFormat(KVWriter):
|
|||
|
||||
|
||||
class TensorBoardOutputFormat(KVWriter):
|
||||
def __init__(self, folder: str):
|
||||
"""
|
||||
Dumps key/value pairs into TensorBoard's numeric format.
|
||||
"""
|
||||
Dumps key/value pairs into TensorBoard's numeric format.
|
||||
|
||||
:param folder: the folder to write the log to
|
||||
"""
|
||||
:param folder: the folder to write the log to
|
||||
"""
|
||||
|
||||
def __init__(self, folder: str):
|
||||
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
|
||||
self.writer = SummaryWriter(log_dir=folder)
|
||||
|
||||
|
|
@ -427,7 +430,7 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWr
|
|||
# ================================================================
|
||||
|
||||
|
||||
class Logger(object):
|
||||
class Logger:
|
||||
"""
|
||||
The logger class.
|
||||
|
||||
|
|
@ -623,7 +626,7 @@ def read_json(filename: str) -> pandas.DataFrame:
|
|||
:return: the data in the json
|
||||
"""
|
||||
data = []
|
||||
with open(filename, "rt") as file_handler:
|
||||
with open(filename) as file_handler:
|
||||
for line in file_handler:
|
||||
data.append(json.loads(line))
|
||||
return pandas.DataFrame(data)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class Monitor(gym.Wrapper):
|
|||
reset_keywords: Tuple[str, ...] = (),
|
||||
info_keywords: Tuple[str, ...] = (),
|
||||
):
|
||||
super(Monitor, self).__init__(env=env)
|
||||
super().__init__(env=env)
|
||||
self.t_start = time.time()
|
||||
if filename is not None:
|
||||
self.results_writer = ResultsWriter(
|
||||
|
|
@ -110,7 +110,7 @@ class Monitor(gym.Wrapper):
|
|||
"""
|
||||
Closes the environment
|
||||
"""
|
||||
super(Monitor, self).close()
|
||||
super().close()
|
||||
if self.results_writer is not None:
|
||||
self.results_writer.close()
|
||||
|
||||
|
|
@ -224,7 +224,7 @@ def load_results(path: str) -> pandas.DataFrame:
|
|||
raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
|
||||
data_frames, headers = [], []
|
||||
for file_name in monitor_files:
|
||||
with open(file_name, "rt") as file_handler:
|
||||
with open(file_name) as file_handler:
|
||||
first_line = file_handler.readline()
|
||||
assert first_line[0] == "#"
|
||||
header = json.loads(first_line[1:])
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class ActionNoise(ABC):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ActionNoise, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
|
|
@ -35,7 +35,7 @@ class NormalActionNoise(ActionNoise):
|
|||
def __init__(self, mean: np.ndarray, sigma: np.ndarray):
|
||||
self._mu = mean
|
||||
self._sigma = sigma
|
||||
super(NormalActionNoise, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def __call__(self) -> np.ndarray:
|
||||
return np.random.normal(self._mu, self._sigma)
|
||||
|
|
@ -72,7 +72,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
|
|||
self.initial_noise = initial_noise
|
||||
self.noise_prev = np.zeros_like(self._mu)
|
||||
self.reset()
|
||||
super(OrnsteinUhlenbeckActionNoise, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def __call__(self) -> np.ndarray:
|
||||
noise = (
|
||||
|
|
@ -105,8 +105,8 @@ class VectorizedActionNoise(ActionNoise):
|
|||
try:
|
||||
self.n_envs = int(n_envs)
|
||||
assert self.n_envs > 0
|
||||
except (TypeError, AssertionError):
|
||||
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")
|
||||
except (TypeError, AssertionError) as e:
|
||||
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e
|
||||
|
||||
self.base_noise = base_noise
|
||||
self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import io
|
||||
import pathlib
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
|
|
@ -28,7 +29,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
:param policy: Policy object
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param learning_rate: learning rate for the optimizer,
|
||||
it can be a function of the current progress remaining (from 1 to 0)
|
||||
:param buffer_size: size of the replay buffer
|
||||
|
|
@ -76,7 +76,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
self,
|
||||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Schedule],
|
||||
buffer_size: int = 1_000_000, # 1e6
|
||||
learning_starts: int = 100,
|
||||
|
|
@ -104,10 +103,9 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
|
||||
):
|
||||
|
||||
super(OffPolicyAlgorithm, self).__init__(
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
policy_base=policy_base,
|
||||
learning_rate=learning_rate,
|
||||
policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=tensorboard_log,
|
||||
|
|
@ -160,8 +158,10 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
|
||||
try:
|
||||
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
|
||||
except ValueError:
|
||||
raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!")
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
|
||||
) from e
|
||||
|
||||
if not isinstance(train_freq[0], int):
|
||||
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
|
||||
|
|
@ -428,8 +428,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
"""
|
||||
Write log.
|
||||
"""
|
||||
time_elapsed = time.time() - self.start_time
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8))
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
|
|
@ -8,7 +9,7 @@ import torch as th
|
|||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
|
|
@ -34,7 +35,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
|
|
@ -62,7 +62,6 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
max_grad_norm: float,
|
||||
use_sde: bool,
|
||||
sde_sample_freq: int,
|
||||
policy_base: Type[BasePolicy] = ActorCriticPolicy,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
monitor_wrapper: bool = True,
|
||||
|
|
@ -74,10 +73,9 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
|
||||
):
|
||||
|
||||
super(OnPolicyAlgorithm, self).__init__(
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
policy_base=policy_base,
|
||||
learning_rate=learning_rate,
|
||||
policy_kwargs=policy_kwargs,
|
||||
verbose=verbose,
|
||||
|
|
@ -257,13 +255,14 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
|
||||
# Display training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
|
||||
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
|
||||
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
|
||||
self.logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
self.logger.record("time/fps", fps)
|
||||
self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
|
||||
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
|
||||
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(step=self.num_timesteps)
|
||||
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class BaseModel(nn.Module, ABC):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(BaseModel, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
|
|
@ -267,7 +267,7 @@ class BasePolicy(BaseModel):
|
|||
"""
|
||||
|
||||
def __init__(self, *args, squash_output: bool = False, **kwargs):
|
||||
super(BasePolicy, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._squash_output = squash_output
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -336,8 +336,8 @@ class BasePolicy(BaseModel):
|
|||
|
||||
with th.no_grad():
|
||||
actions = self._predict(observation, deterministic=deterministic)
|
||||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
# Convert to numpy, and reshape to the original action shape
|
||||
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if self.squash_output:
|
||||
|
|
@ -350,7 +350,7 @@ class BasePolicy(BaseModel):
|
|||
|
||||
# Remove batch dimension if needed
|
||||
if not vectorized_env:
|
||||
actions = actions[0]
|
||||
actions = actions.squeeze(axis=0)
|
||||
|
||||
return actions, state
|
||||
|
||||
|
|
@ -437,7 +437,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
if optimizer_class == th.optim.Adam:
|
||||
optimizer_kwargs["eps"] = 1e-5
|
||||
|
||||
super(ActorCriticPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor_class,
|
||||
|
|
@ -592,6 +592,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
distribution = self._get_action_dist_from_latent(latent_pi)
|
||||
actions = distribution.get_actions(deterministic=deterministic)
|
||||
log_prob = distribution.log_prob(actions)
|
||||
actions = actions.reshape((-1,) + self.action_space.shape)
|
||||
return actions, values, log_prob
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
|
||||
|
|
@ -724,7 +725,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(ActorCriticCnnPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -799,7 +800,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(MultiInputActorCriticPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -895,68 +896,3 @@ class ContinuousCritic(BaseModel):
|
|||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
return self.q_networks[0](th.cat([features, actions], dim=1))
|
||||
|
||||
|
||||
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
|
||||
|
||||
|
||||
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
|
||||
"""
|
||||
Returns the registered policy from the base type and name.
|
||||
See `register_policy` for registering policies and explanation.
|
||||
|
||||
:param base_policy_type: the base policy class
|
||||
:param name: the policy name
|
||||
:return: the policy
|
||||
"""
|
||||
if base_policy_type not in _policy_registry:
|
||||
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
|
||||
if name not in _policy_registry[base_policy_type]:
|
||||
raise KeyError(
|
||||
f"Error: unknown policy type {name},"
|
||||
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!"
|
||||
)
|
||||
return _policy_registry[base_policy_type][name]
|
||||
|
||||
|
||||
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
|
||||
"""
|
||||
Register a policy, so it can be called using its name.
|
||||
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...).
|
||||
|
||||
The goal here is to standardize policy naming, e.g.
|
||||
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
|
||||
and they receive respective policies that work for them.
|
||||
Consider following:
|
||||
|
||||
OnlinePolicy
|
||||
-- OnlineMlpPolicy ("MlpPolicy")
|
||||
-- OnlineCnnPolicy ("CnnPolicy")
|
||||
OfflinePolicy
|
||||
-- OfflineMlpPolicy ("MlpPolicy")
|
||||
-- OfflineCnnPolicy ("CnnPolicy")
|
||||
|
||||
Two policies have name "MlpPolicy" and two have "CnnPolicy".
|
||||
In `get_policy_from_name`, the parent class (e.g. OnlinePolicy)
|
||||
is given and used to select and return the correct policy.
|
||||
|
||||
:param name: the policy name
|
||||
:param policy: the policy class
|
||||
"""
|
||||
sub_class = None
|
||||
for cls in BasePolicy.__subclasses__():
|
||||
if issubclass(policy, cls):
|
||||
sub_class = cls
|
||||
break
|
||||
if sub_class is None:
|
||||
raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")
|
||||
|
||||
if sub_class not in _policy_registry:
|
||||
_policy_registry[sub_class] = {}
|
||||
if name in _policy_registry[sub_class]:
|
||||
# Check if the registered policy is same
|
||||
# we try to register. If not so,
|
||||
# do not override and complain.
|
||||
if _policy_registry[sub_class][name] != policy:
|
||||
raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.")
|
||||
_policy_registry[sub_class][name] = policy
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Tuple, Union
|
|||
import numpy as np
|
||||
|
||||
|
||||
class RunningMeanStd(object):
|
||||
class RunningMeanStd:
|
||||
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
|
||||
"""
|
||||
Calulates the running mean and std of a data stream
|
||||
|
|
|
|||
|
|
@ -206,8 +206,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb
|
|||
mode = mode.lower()
|
||||
try:
|
||||
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
|
||||
except KeyError:
|
||||
raise ValueError("Expected mode to be either 'w' or 'r'.")
|
||||
except KeyError as e:
|
||||
raise ValueError("Expected mode to be either 'w' or 'r'.") from e
|
||||
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
|
||||
e1 = "writable" if "w" == mode else "readable"
|
||||
raise ValueError(f"Expected a {e1} file.")
|
||||
|
|
@ -441,7 +441,7 @@ def load_from_zip_file(
|
|||
# State dicts. Store into params dictionary
|
||||
# with same name as in .zip file (without .pth)
|
||||
params[os.path.splitext(file_path)[0]] = th_object
|
||||
except zipfile.BadZipFile:
|
||||
except zipfile.BadZipFile as e:
|
||||
# load_path wasn't a zip file
|
||||
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
|
||||
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
|
||||
return data, params, pytorch_variables
|
||||
|
|
|
|||
|
|
@ -54,21 +54,21 @@ class RMSpropTFLike(Optimizer):
|
|||
centered: bool = False,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||
if not 0.0 <= momentum:
|
||||
raise ValueError("Invalid momentum value: {}".format(momentum))
|
||||
raise ValueError(f"Invalid momentum value: {momentum}")
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
if not 0.0 <= alpha:
|
||||
raise ValueError("Invalid alpha value: {}".format(alpha))
|
||||
raise ValueError(f"Invalid alpha value: {alpha}")
|
||||
|
||||
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
|
||||
super(RMSpropTFLike, self).__init__(params, defaults)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
super(RMSpropTFLike, self).__setstate__(state)
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault("momentum", 0)
|
||||
group.setdefault("centered", False)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class BaseFeaturesExtractor(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.Space, features_dim: int = 0):
|
||||
super(BaseFeaturesExtractor, self).__init__()
|
||||
super().__init__()
|
||||
assert features_dim > 0
|
||||
self._observation_space = observation_space
|
||||
self._features_dim = features_dim
|
||||
|
|
@ -41,7 +41,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
|
|||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.Space):
|
||||
super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
|
||||
super().__init__(observation_space, get_flattened_obs_dim(observation_space))
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def forward(self, observations: th.Tensor) -> th.Tensor:
|
||||
|
|
@ -50,7 +50,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
|
|||
|
||||
class NatureCNN(BaseFeaturesExtractor):
|
||||
"""
|
||||
CNN from DQN nature paper:
|
||||
CNN from DQN Nature paper:
|
||||
Mnih, Volodymyr, et al.
|
||||
"Human-level control through deep reinforcement learning."
|
||||
Nature 518.7540 (2015): 529-533.
|
||||
|
|
@ -61,7 +61,7 @@ class NatureCNN(BaseFeaturesExtractor):
|
|||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
|
||||
super(NatureCNN, self).__init__(observation_space, features_dim)
|
||||
super().__init__(observation_space, features_dim)
|
||||
# We assume CxHxW images (channels first)
|
||||
# Re-ordering will be done by pre-preprocessing or wrapper
|
||||
assert is_image_space(observation_space, check_channels=False), (
|
||||
|
|
@ -169,7 +169,7 @@ class MlpExtractor(nn.Module):
|
|||
activation_fn: Type[nn.Module],
|
||||
device: Union[th.device, str] = "auto",
|
||||
):
|
||||
super(MlpExtractor, self).__init__()
|
||||
super().__init__()
|
||||
device = get_device(device)
|
||||
shared_net, policy_net, value_net = [], [], []
|
||||
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
|
||||
|
|
@ -250,7 +250,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
|
|||
|
||||
def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
|
||||
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
|
||||
super(CombinedExtractor, self).__init__(observation_space, features_dim=1)
|
||||
super().__init__(observation_space, features_dim=1)
|
||||
|
||||
extractors = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class ReplayBufferSamples(NamedTuple):
|
|||
class DictReplayBufferSamples(ReplayBufferSamples):
|
||||
observations: TensorDict
|
||||
actions: th.Tensor
|
||||
next_observations: th.Tensor
|
||||
next_observations: TensorDict
|
||||
dones: th.Tensor
|
||||
rewards: th.Tensor
|
||||
|
||||
|
|
|
|||
|
|
@ -154,15 +154,18 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
|
|||
return device
|
||||
|
||||
|
||||
def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int:
|
||||
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
|
||||
"""
|
||||
Returns the latest run number for the given log name and log path,
|
||||
by finding the greatest number in the directories.
|
||||
|
||||
:param log_path: Path to the log folder containing several runs.
|
||||
:param log_name: Name of the experiment. Each run is stored
|
||||
in a folder named ``log_name_1``, ``log_name_2``, ...
|
||||
:return: latest run number
|
||||
"""
|
||||
max_run_id = 0
|
||||
for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"):
|
||||
for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
|
||||
file_name = path.split(os.sep)[-1]
|
||||
ext = file_name.split("_")[-1]
|
||||
if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
|
||||
|
|
|
|||
|
|
@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
|
|||
env_tmp, eval_env_tmp = env, eval_env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
if isinstance(env_tmp, VecNormalize):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
# Only synchronize if observation normalization exists
|
||||
if hasattr(env_tmp, "obs_rms"):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
|
||||
env_tmp = env_tmp.venv
|
||||
eval_env_tmp = eval_env_tmp.venv
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ class VecEnvWrapper(VecEnv):
|
|||
own_class = f"{type(self).__module__}.{type(self).__name__}"
|
||||
error_str = (
|
||||
f"Error: Recursive attribute lookup for {name} from {own_class} is "
|
||||
"ambiguous and hides attribute from {blocked_class}"
|
||||
f"ambiguous and hides attribute from {blocked_class}"
|
||||
)
|
||||
raise AttributeError(error_str)
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,9 @@ class DummyVecEnv(VecEnv):
|
|||
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
seeds = list()
|
||||
if seed is None:
|
||||
seed = np.random.randint(0, 2**32 - 1)
|
||||
seeds = []
|
||||
for idx, env in enumerate(self.envs):
|
||||
seeds.append(env.seed(seed + idx))
|
||||
return seeds
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from gym import spaces
|
|||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
|
||||
|
||||
class StackedObservations(object):
|
||||
class StackedObservations:
|
||||
"""
|
||||
Frame stacking wrapper for data.
|
||||
|
||||
|
|
|
|||
|
|
@ -123,6 +123,8 @@ class SubprocVecEnv(VecEnv):
|
|||
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
if seed is None:
|
||||
seed = np.random.randint(0, 2**32 - 1)
|
||||
for idx, remote in enumerate(self.remotes):
|
||||
remote.send(("seed", seed + idx))
|
||||
return [remote.recv() for remote in self.remotes]
|
||||
|
|
@ -215,6 +217,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space
|
|||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
|
||||
obs_len = len(space.spaces)
|
||||
return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
|
||||
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len))
|
||||
else:
|
||||
return np.stack(obs)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) ->
|
|||
return obs_dict
|
||||
elif isinstance(obs_space, gym.spaces.Tuple):
|
||||
assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
|
||||
return tuple((obs_dict[i] for i in range(len(obs_space.spaces))))
|
||||
return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
|
||||
else:
|
||||
assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
|
||||
return obs_dict[None]
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class VecTransposeImage(VecEnvWrapper):
|
|||
self.skip = skip
|
||||
# Do nothing
|
||||
if skip:
|
||||
super(VecTransposeImage, self).__init__(venv)
|
||||
super().__init__(venv)
|
||||
return
|
||||
|
||||
if isinstance(venv.observation_space, spaces.dict.Dict):
|
||||
|
|
@ -39,7 +39,7 @@ class VecTransposeImage(VecEnvWrapper):
|
|||
observation_space.spaces[key] = self.transpose_space(space, key)
|
||||
else:
|
||||
observation_space = self.transpose_space(venv.observation_space)
|
||||
super(VecTransposeImage, self).__init__(venv, observation_space=observation_space)
|
||||
super().__init__(venv, observation_space=observation_space)
|
||||
|
||||
@staticmethod
|
||||
def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class DDPG(TD3):
|
|||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(DDPG, self).__init__(
|
||||
super().__init__(
|
||||
policy=policy,
|
||||
env=env,
|
||||
learning_rate=learning_rate,
|
||||
|
|
@ -127,7 +127,7 @@ class DDPG(TD3):
|
|||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
|
||||
return super(DDPG, self).learn(
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
|
|
|
|||
|
|
@ -8,10 +8,11 @@ from torch.nn import functional as F
|
|||
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.preprocessing import maybe_transpose
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
|
||||
from stable_baselines3.dqn.policies import DQNPolicy
|
||||
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
|
||||
|
||||
|
||||
class DQN(OffPolicyAlgorithm):
|
||||
|
|
@ -19,7 +20,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
Deep Q-Network (DQN)
|
||||
|
||||
Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
|
||||
Default hyperparameters are taken from the nature paper,
|
||||
Default hyperparameters are taken from the Nature paper,
|
||||
except for the optimizer and learning rate that were taken from Stable Baselines defaults.
|
||||
|
||||
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
|
|
@ -59,6 +60,12 @@ class DQN(OffPolicyAlgorithm):
|
|||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||
"MlpPolicy": MlpPolicy,
|
||||
"CnnPolicy": CnnPolicy,
|
||||
"MultiInputPolicy": MultiInputPolicy,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[DQNPolicy]],
|
||||
|
|
@ -88,10 +95,9 @@ class DQN(OffPolicyAlgorithm):
|
|||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(DQN, self).__init__(
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
DQNPolicy,
|
||||
learning_rate,
|
||||
buffer_size,
|
||||
learning_starts,
|
||||
|
|
@ -132,7 +138,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
super(DQN, self)._setup_model()
|
||||
super()._setup_model()
|
||||
self._create_aliases()
|
||||
self.exploration_schedule = get_linear_fn(
|
||||
self.exploration_initial_eps,
|
||||
|
|
@ -255,7 +261,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
|
||||
return super(DQN, self).learn(
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
|
|
@ -268,7 +274,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]
|
||||
return super()._excluded_save_params() + ["q_net", "q_net_target"]
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "policy.optimizer"]
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import gym
|
|||
import torch as th
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
CombinedExtractor,
|
||||
|
|
@ -37,7 +37,7 @@ class QNetwork(BasePolicy):
|
|||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
):
|
||||
super(QNetwork, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
|
|
@ -118,7 +118,7 @@ class DQNPolicy(BasePolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(DQNPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor_class,
|
||||
|
|
@ -239,7 +239,7 @@ class CnnPolicy(DQNPolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -284,7 +284,7 @@ class MultiInputPolicy(DQNPolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(MultiInputPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -296,8 +296,3 @@ class MultiInputPolicy(DQNPolicy):
|
|||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
)
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||
|
|
|
|||
|
|
@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in
|
|||
if current_max_episode_length is None:
|
||||
raise AttributeError
|
||||
# if not available check if a valid value was passed as an argument
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
"The max episode length could not be inferred.\n"
|
||||
"You must specify a `max_episode_steps` when registering the environment,\n"
|
||||
"use a `gym.wrappers.TimeLimit` wrapper "
|
||||
"or pass `max_episode_length` to the model constructor"
|
||||
)
|
||||
) from e
|
||||
return current_max_episode_length
|
||||
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
self,
|
||||
env: VecEnv,
|
||||
buffer_size: int,
|
||||
device: Union[th.device, str] = "cpu",
|
||||
device: Union[th.device, str] = "auto",
|
||||
replay_buffer: Optional[DictReplayBuffer] = None,
|
||||
max_episode_length: Optional[int] = None,
|
||||
n_sampled_goal: int = 4,
|
||||
|
|
@ -82,7 +82,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
handle_timeout_termination: bool = True,
|
||||
):
|
||||
|
||||
super(HerReplayBuffer, self).__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
|
||||
super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)
|
||||
|
||||
# convert goal_selection_strategy into GoalSelectionStrategy if string
|
||||
if isinstance(goal_selection_strategy, str):
|
||||
|
|
@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
|
||||
# replay with random state which comes from the same episode and was observed after current transition
|
||||
transitions_indices = np.random.randint(
|
||||
transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices]
|
||||
transitions_indices[her_indices], self.episode_lengths[her_episode_indices]
|
||||
)
|
||||
|
||||
elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE:
|
||||
|
|
@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
else:
|
||||
raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!")
|
||||
|
||||
return self._buffer["achieved_goal"][her_episode_indices, transitions_indices]
|
||||
return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices]
|
||||
|
||||
def _sample_transitions(
|
||||
self,
|
||||
|
|
@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer):
|
|||
|
||||
ep_lengths = self.episode_lengths[episode_indices]
|
||||
|
||||
# Special case when using the "future" goal sampling strategy
|
||||
# we cannot sample all transitions, we have to remove the last timestep
|
||||
if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE:
|
||||
# restrict the sampling domain when ep_lengths > 1
|
||||
# otherwise filter out the indices
|
||||
her_indices = her_indices[ep_lengths[her_indices] > 1]
|
||||
ep_lengths[her_indices] -= 1
|
||||
|
||||
if online_sampling:
|
||||
# Select which transitions to use
|
||||
transitions_indices = np.random.randint(ep_lengths)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,7 @@
|
|||
# This file is here just to define MlpPolicy/CnnPolicy
|
||||
# that work for PPO
|
||||
from stable_baselines3.common.policies import (
|
||||
ActorCriticCnnPolicy,
|
||||
ActorCriticPolicy,
|
||||
MultiInputActorCriticPolicy,
|
||||
register_policy,
|
||||
)
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||
|
||||
MlpPolicy = ActorCriticPolicy
|
||||
CnnPolicy = ActorCriticCnnPolicy
|
||||
MultiInputPolicy = MultiInputActorCriticPolicy
|
||||
|
||||
register_policy("MlpPolicy", ActorCriticPolicy)
|
||||
register_policy("CnnPolicy", ActorCriticCnnPolicy)
|
||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from gym import spaces
|
|||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
|
||||
|
|
@ -19,7 +19,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
Paper: https://arxiv.org/abs/1707.06347
|
||||
Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
|
||||
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
|
||||
and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
|
||||
Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
|
||||
|
||||
Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
|
||||
|
||||
|
|
@ -65,6 +65,12 @@ class PPO(OnPolicyAlgorithm):
|
|||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||
"MlpPolicy": ActorCriticPolicy,
|
||||
"CnnPolicy": ActorCriticCnnPolicy,
|
||||
"MultiInputPolicy": MultiInputActorCriticPolicy,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[ActorCriticPolicy]],
|
||||
|
|
@ -93,7 +99,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(PPO, self).__init__(
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
learning_rate=learning_rate,
|
||||
|
|
@ -156,7 +162,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
super(PPO, self)._setup_model()
|
||||
super()._setup_model()
|
||||
|
||||
# Initialize schedules for policy/value clipping
|
||||
self.clip_range = get_schedule_fn(self.clip_range)
|
||||
|
|
@ -301,7 +307,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
) -> "PPO":
|
||||
|
||||
return super(PPO, self).learn(
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import torch as th
|
|||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
|
||||
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
@ -65,7 +65,7 @@ class Actor(BasePolicy):
|
|||
clip_mean: float = 2.0,
|
||||
normalize_images: bool = True,
|
||||
):
|
||||
super(Actor, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
|
|
@ -235,9 +235,9 @@ class SACPolicy(BasePolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super(SACPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor_class,
|
||||
|
|
@ -248,10 +248,7 @@ class SACPolicy(BasePolicy):
|
|||
)
|
||||
|
||||
if net_arch is None:
|
||||
if features_extractor_class == NatureCNN:
|
||||
net_arch = []
|
||||
else:
|
||||
net_arch = [256, 256]
|
||||
net_arch = [256, 256]
|
||||
|
||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||
|
||||
|
|
@ -422,9 +419,9 @@ class CnnPolicy(SACPolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -493,9 +490,9 @@ class MultiInputPolicy(SACPolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super(MultiInputPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -514,8 +511,3 @@ class MultiInputPolicy(SACPolicy):
|
|||
n_critics,
|
||||
share_features_extractor,
|
||||
)
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,10 @@ from torch.nn import functional as F
|
|||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import polyak_update
|
||||
from stable_baselines3.sac.policies import SACPolicy
|
||||
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
|
||||
|
||||
|
||||
class SAC(OffPolicyAlgorithm):
|
||||
|
|
@ -72,6 +73,12 @@ class SAC(OffPolicyAlgorithm):
|
|||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||
"MlpPolicy": MlpPolicy,
|
||||
"CnnPolicy": CnnPolicy,
|
||||
"MultiInputPolicy": MultiInputPolicy,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[SACPolicy]],
|
||||
|
|
@ -103,10 +110,9 @@ class SAC(OffPolicyAlgorithm):
|
|||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(SAC, self).__init__(
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
SACPolicy,
|
||||
learning_rate,
|
||||
buffer_size,
|
||||
learning_starts,
|
||||
|
|
@ -144,7 +150,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
super(SAC, self)._setup_model()
|
||||
super()._setup_model()
|
||||
self._create_aliases()
|
||||
# Target entropy is used when learning the entropy coefficient
|
||||
if self.target_entropy == "auto":
|
||||
|
|
@ -248,7 +254,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
current_q_values = self.critic(replay_data.observations, replay_data.actions)
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
|
||||
critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# Optimize the critic
|
||||
|
|
@ -297,7 +303,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
|
||||
return super(SAC, self).learn(
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
|
|
@ -310,7 +316,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"]
|
||||
return super()._excluded_save_params() + ["actor", "critic", "critic_target"]
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import gym
|
|||
import torch as th
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
|
||||
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.torch_layers import (
|
||||
BaseFeaturesExtractor,
|
||||
|
|
@ -42,7 +42,7 @@ class Actor(BasePolicy):
|
|||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
):
|
||||
super(Actor, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor=features_extractor,
|
||||
|
|
@ -119,9 +119,9 @@ class TD3Policy(BasePolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super(TD3Policy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
features_extractor_class,
|
||||
|
|
@ -134,7 +134,7 @@ class TD3Policy(BasePolicy):
|
|||
# Default network architecture, from the original paper
|
||||
if net_arch is None:
|
||||
if features_extractor_class == NatureCNN:
|
||||
net_arch = []
|
||||
net_arch = [256, 256]
|
||||
else:
|
||||
net_arch = [400, 300]
|
||||
|
||||
|
|
@ -281,9 +281,9 @@ class CnnPolicy(TD3Policy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -335,9 +335,9 @@ class MultiInputPolicy(TD3Policy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
share_features_extractor: bool = False,
|
||||
):
|
||||
super(MultiInputPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -351,8 +351,3 @@ class MultiInputPolicy(TD3Policy):
|
|||
n_critics,
|
||||
share_features_extractor,
|
||||
)
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
register_policy("MultiInputPolicy", MultiInputPolicy)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,10 @@ from torch.nn import functional as F
|
|||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
||||
from stable_baselines3.common.utils import polyak_update
|
||||
from stable_baselines3.td3.policies import TD3Policy
|
||||
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
|
||||
|
||||
|
||||
class TD3(OffPolicyAlgorithm):
|
||||
|
|
@ -60,6 +61,12 @@ class TD3(OffPolicyAlgorithm):
|
|||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||
"MlpPolicy": MlpPolicy,
|
||||
"CnnPolicy": CnnPolicy,
|
||||
"MultiInputPolicy": MultiInputPolicy,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: Union[str, Type[TD3Policy]],
|
||||
|
|
@ -88,10 +95,9 @@ class TD3(OffPolicyAlgorithm):
|
|||
_init_setup_model: bool = True,
|
||||
):
|
||||
|
||||
super(TD3, self).__init__(
|
||||
super().__init__(
|
||||
policy,
|
||||
env,
|
||||
TD3Policy,
|
||||
learning_rate,
|
||||
buffer_size,
|
||||
learning_starts,
|
||||
|
|
@ -123,7 +129,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
super(TD3, self)._setup_model()
|
||||
super()._setup_model()
|
||||
self._create_aliases()
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
|
|
@ -162,7 +168,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
current_q_values = self.critic(replay_data.observations, replay_data.actions)
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
|
||||
critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# Optimize the critics
|
||||
|
|
@ -202,7 +208,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
reset_num_timesteps: bool = True,
|
||||
) -> OffPolicyAlgorithm:
|
||||
|
||||
return super(TD3, self).learn(
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
callback=callback,
|
||||
log_interval=log_interval,
|
||||
|
|
@ -215,7 +221,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
)
|
||||
|
||||
def _excluded_save_params(self) -> List[str]:
|
||||
return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
|
||||
return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]
|
||||
|
||||
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
||||
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.1a0
|
||||
1.6.1a0
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ import pytest
|
|||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
|
||||
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import VecNormalize
|
||||
|
||||
|
||||
|
|
@ -71,7 +72,7 @@ def test_replay_buffer_normalization(replay_buffer_cls):
|
|||
env = make_vec_env(env)
|
||||
env = VecNormalize(env)
|
||||
|
||||
buffer = replay_buffer_cls(100, env.observation_space, env.action_space)
|
||||
buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")
|
||||
|
||||
# Interract and store transitions
|
||||
env.reset()
|
||||
|
|
@ -94,3 +95,47 @@ def test_replay_buffer_normalization(replay_buffer_cls):
|
|||
assert th.allclose(observations.mean(0), th.zeros(1), atol=1)
|
||||
# Test reward normalization
|
||||
assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
|
||||
def test_device_buffer(replay_buffer_cls, device):
|
||||
if device == "cuda" and not th.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
env = {
|
||||
RolloutBuffer: DummyEnv,
|
||||
DictRolloutBuffer: DummyDictEnv,
|
||||
ReplayBuffer: DummyEnv,
|
||||
DictReplayBuffer: DummyDictEnv,
|
||||
}[replay_buffer_cls]
|
||||
env = make_vec_env(env)
|
||||
|
||||
buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)
|
||||
|
||||
# Interract and store transitions
|
||||
obs = env.reset()
|
||||
for _ in range(100):
|
||||
action = env.action_space.sample()
|
||||
next_obs, reward, done, info = env.step(action)
|
||||
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
|
||||
episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
|
||||
buffer.add(obs, action, reward, episode_start, values, log_prob)
|
||||
else:
|
||||
buffer.add(obs, next_obs, action, reward, done, info)
|
||||
obs = next_obs
|
||||
|
||||
# Get data from the buffer
|
||||
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
|
||||
data = buffer.get(50)
|
||||
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
|
||||
data = buffer.sample(50)
|
||||
|
||||
# Check that all data are on the desired device
|
||||
desired_device = get_device(device).type
|
||||
for value in list(data):
|
||||
if isinstance(value, dict):
|
||||
for key in value.keys():
|
||||
assert value[key].device.type == desired_device
|
||||
elif isinstance(value, th.Tensor):
|
||||
assert value.device.type == desired_device
|
||||
|
|
|
|||
|
|
@ -163,7 +163,9 @@ def test_categorical(dist, CAT_ACTIONS):
|
|||
BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
|
||||
CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
|
||||
DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
|
||||
MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))),
|
||||
MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution(
|
||||
th.rand(1, sum([N_ACTIONS, N_ACTIONS]))
|
||||
),
|
||||
SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
|
||||
StateDependentNoiseDistribution(N_ACTIONS).proba_distribution(
|
||||
th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])
|
||||
|
|
|
|||
|
|
@ -141,6 +141,8 @@ def test_non_default_spaces(new_obs_space):
|
|||
spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32),
|
||||
# Same boundaries
|
||||
spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
|
||||
# Unbounded action space
|
||||
spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
|
||||
# Almost good, except for one dim
|
||||
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
|
||||
],
|
||||
|
|
@ -156,8 +158,14 @@ def test_non_default_action_spaces(new_action_space):
|
|||
# Change the action space
|
||||
env.action_space = new_action_space
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
check_env(env)
|
||||
# Unbounded action space throws an error,
|
||||
# the rest only warning
|
||||
if not np.all(np.isfinite(env.action_space.low)):
|
||||
with pytest.raises(AssertionError), pytest.warns(UserWarning):
|
||||
check_env(env)
|
||||
else:
|
||||
with pytest.warns(UserWarning):
|
||||
check_env(env)
|
||||
|
||||
|
||||
def check_reset_assert_error(env, new_reset_return):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from stable_baselines3.common.policies import ActorCriticPolicy
|
|||
|
||||
class CustomEnv(gym.Env):
|
||||
def __init__(self, max_steps=8):
|
||||
super(CustomEnv, self).__init__()
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.max_steps = max_steps
|
||||
|
|
@ -54,7 +54,7 @@ class InfiniteHorizonEnv(gym.Env):
|
|||
|
||||
class CheckGAECallback(BaseCallback):
|
||||
def __init__(self):
|
||||
super(CheckGAECallback, self).__init__(verbose=0)
|
||||
super().__init__(verbose=0)
|
||||
|
||||
def _on_rollout_end(self):
|
||||
buffer = self.model.rollout_buffer
|
||||
|
|
@ -99,7 +99,7 @@ class CustomPolicy(ActorCriticPolicy):
|
|||
"""Custom Policy with a constant value function"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CustomPolicy, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.constant_value = 0.0
|
||||
|
||||
def forward(self, obs, deterministic=False):
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling):
|
|||
params = deepcopy(model.policy.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
|
||||
|
||||
# Update model parameters with the new random values
|
||||
model.policy.load_state_dict(random_params)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import time
|
||||
from typing import Sequence
|
||||
from unittest import mock
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
|
@ -381,3 +382,16 @@ def test_fps_logger(tmp_path, algo):
|
|||
# third time, FPS should be the same
|
||||
model.learn(100, log_interval=1, reset_num_timesteps=False)
|
||||
assert max_fps / 10 <= logger.name_to_value["time/fps"] <= max_fps
|
||||
|
||||
|
||||
@pytest.mark.parametrize("algo", [A2C, DQN])
|
||||
def test_fps_no_div_zero(algo):
|
||||
"""Set time to constant and train algorithm to check no division by zero error.
|
||||
|
||||
Time can appear to be constant during short runs on platforms with low-precision
|
||||
timers. We should avoid division by zero errors e.g. when computing FPS in
|
||||
this situation."""
|
||||
with mock.patch("time.time", lambda: 42.0):
|
||||
with mock.patch("time.time_ns", lambda: 42.0):
|
||||
model = algo("MlpPolicy", "CartPole-v1")
|
||||
model.learn(total_timesteps=100)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ def test_monitor(tmp_path):
|
|||
"""
|
||||
env = gym.make("CartPole-v1")
|
||||
env.seed(0)
|
||||
monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
|
||||
monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
|
||||
monitor_env = Monitor(env, monitor_file)
|
||||
monitor_env.reset()
|
||||
total_steps = 1000
|
||||
|
|
@ -37,7 +37,7 @@ def test_monitor(tmp_path):
|
|||
assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards)
|
||||
_ = monitor_env.get_episode_times()
|
||||
|
||||
with open(monitor_file, "rt") as file_handler:
|
||||
with open(monitor_file) as file_handler:
|
||||
first_line = file_handler.readline()
|
||||
assert first_line.startswith("#")
|
||||
metadata = json.loads(first_line[1:])
|
||||
|
|
@ -56,7 +56,7 @@ def test_monitor_load_results(tmp_path):
|
|||
tmp_path = str(tmp_path)
|
||||
env1 = gym.make("CartPole-v1")
|
||||
env1.seed(0)
|
||||
monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
|
||||
monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
|
||||
monitor_env1 = Monitor(env1, monitor_file1)
|
||||
|
||||
monitor_files = get_monitor_files(tmp_path)
|
||||
|
|
@ -76,7 +76,7 @@ def test_monitor_load_results(tmp_path):
|
|||
|
||||
env2 = gym.make("CartPole-v1")
|
||||
env2.seed(0)
|
||||
monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
|
||||
monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv")
|
||||
monitor_env2 = Monitor(env2, monitor_file2)
|
||||
monitor_files = get_monitor_files(tmp_path)
|
||||
assert len(monitor_files) == 2
|
||||
|
|
|
|||
|
|
@ -73,11 +73,13 @@ def test_predict(model_class, env_id, device):
|
|||
|
||||
obs = env.reset()
|
||||
action, _ = model.predict(obs)
|
||||
assert isinstance(action, np.ndarray)
|
||||
assert action.shape == env.action_space.shape
|
||||
assert env.action_space.contains(action)
|
||||
|
||||
vec_env_obs = vec_env.reset()
|
||||
action, _ = model.predict(vec_env_obs)
|
||||
assert isinstance(action, np.ndarray)
|
||||
assert action.shape[0] == vec_env_obs.shape[0]
|
||||
|
||||
# Special case for DQN to check the epsilon greedy exploration
|
||||
|
|
|
|||
|
|
@ -10,7 +10,10 @@ normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [TD3, DDPG])
|
||||
@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))])
|
||||
@pytest.mark.parametrize(
|
||||
"action_noise",
|
||||
[normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))],
|
||||
)
|
||||
def test_deterministic_pg(model_class, action_noise):
|
||||
"""
|
||||
Test for DDPG and variants (TD3).
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ def test_save_load(tmp_path, model_class):
|
|||
model.set_parameters(invalid_object_params, exact_match=False)
|
||||
|
||||
# Test that exact_match catches when something was missed.
|
||||
missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1])
|
||||
missing_object_params = {k: v for k, v in list(original_params.items())[:-1]}
|
||||
with pytest.raises(ValueError):
|
||||
model.set_parameters(missing_object_params, exact_match=True)
|
||||
|
||||
|
|
@ -375,6 +375,9 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
|
|||
select_env(model_class),
|
||||
buffer_size=100,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
# we cannot use optimize_memory_usage and handle_timeout_termination
|
||||
# at the same time
|
||||
replay_buffer_kwargs={"handle_timeout_termination": not optimize_memory_usage},
|
||||
policy_kwargs=dict(net_arch=[64]),
|
||||
learning_starts=10,
|
||||
)
|
||||
|
|
@ -446,7 +449,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
|
|||
params = deepcopy(policy.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
|
||||
|
||||
# Update model parameters with the new random values
|
||||
policy.load_state_dict(random_params)
|
||||
|
|
@ -537,7 +540,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str):
|
|||
params = deepcopy(q_net.state_dict())
|
||||
|
||||
# Modify all parameters to be random values
|
||||
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
|
||||
random_params = {param_name: th.rand_like(param) for param_name, param in params.items()}
|
||||
|
||||
# Update model parameters with the new random values
|
||||
q_net.load_state_dict(random_params)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from stable_baselines3.common.evaluation import evaluate_policy
|
|||
|
||||
class DummyMultiDiscreteSpace(gym.Env):
|
||||
def __init__(self, nvec):
|
||||
super(DummyMultiDiscreteSpace, self).__init__()
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.MultiDiscrete(nvec)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ class DummyMultiDiscreteSpace(gym.Env):
|
|||
|
||||
class DummyMultiBinary(gym.Env):
|
||||
def __init__(self, n):
|
||||
super(DummyMultiBinary, self).__init__()
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.MultiBinary(n)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
|
||||
|
|
@ -33,6 +33,19 @@ class DummyMultiBinary(gym.Env):
|
|||
return self.observation_space.sample(), 0.0, False, {}
|
||||
|
||||
|
||||
class DummyMultidimensionalAction(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
||||
def step(self, action):
|
||||
return self.observation_space.sample(), 0.0, False, {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
|
||||
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
|
||||
def test_identity_spaces(model_class, env):
|
||||
|
|
@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
|
||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
|
||||
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
|
||||
def test_action_spaces(model_class, env):
|
||||
kwargs = {}
|
||||
if model_class in [SAC, DDPG, TD3]:
|
||||
supported_action_space = env == "Pendulum-v1"
|
||||
supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
|
||||
kwargs["learning_starts"] = 2
|
||||
kwargs["train_freq"] = 32
|
||||
elif model_class == DQN:
|
||||
supported_action_space = env == "CartPole-v1"
|
||||
elif model_class in [A2C, PPO]:
|
||||
supported_action_space = True
|
||||
kwargs["n_steps"] = 64
|
||||
|
||||
if supported_action_space:
|
||||
model_class("MlpPolicy", env)
|
||||
model = model_class("MlpPolicy", env, **kwargs)
|
||||
if isinstance(env, DummyMultidimensionalAction):
|
||||
model.learn(64)
|
||||
else:
|
||||
with pytest.raises(AssertionError):
|
||||
model_class("MlpPolicy", env)
|
||||
|
||||
|
||||
def test_sde_multi_dim():
|
||||
SAC(
|
||||
"MlpPolicy",
|
||||
DummyMultidimensionalAction(),
|
||||
learning_starts=10,
|
||||
use_sde=True,
|
||||
sde_sample_freq=2,
|
||||
use_sde_at_warmup=True,
|
||||
).learn(20)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
|
||||
@pytest.mark.parametrize("env", ["Taxi-v3"])
|
||||
def test_discrete_obs_space(model_class, env):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import pytest
|
||||
|
||||
from stable_baselines3 import A2C, PPO, SAC, TD3
|
||||
from stable_baselines3.common.utils import get_latest_run_id
|
||||
|
||||
MODEL_DICT = {
|
||||
"a2c": (A2C, "CartPole-v1"),
|
||||
|
|
@ -35,3 +36,13 @@ def test_tensorboard(tmp_path, model_name):
|
|||
assert os.path.isdir(tmp_path / str(logname + "_1"))
|
||||
# Check that the log dir name increments correctly
|
||||
assert os.path.isdir(tmp_path / str(logname + "_2"))
|
||||
|
||||
|
||||
def test_escape_log_name(tmp_path):
|
||||
# Log name that must be escaped
|
||||
log_name = "filename[16, 16]"
|
||||
# Create folder
|
||||
os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True)
|
||||
os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True)
|
||||
last_run_id = get_latest_run_id(tmp_path, log_name)
|
||||
assert last_run_id == 2
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
|
|||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.Space):
|
||||
super(FlattenBatchNormDropoutExtractor, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
get_flattened_obs_dim(observation_space),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -180,7 +180,7 @@ class AlwaysDoneWrapper(gym.Wrapper):
|
|||
# Pretends that environment only has single step for each
|
||||
# episode.
|
||||
def __init__(self, env):
|
||||
super(AlwaysDoneWrapper, self).__init__(env)
|
||||
super().__init__(env)
|
||||
self.last_obs = None
|
||||
self.needs_reset = True
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class NanAndInfEnv(gym.Env):
|
|||
metadata = {"render.modes": ["human"]}
|
||||
|
||||
def __init__(self):
|
||||
super(NanAndInfEnv, self).__init__()
|
||||
super().__init__()
|
||||
self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class CustomGymEnv(gym.Env):
|
|||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
reward = 1
|
||||
reward = float(np.random.rand())
|
||||
self._choose_next_state()
|
||||
self.current_step += 1
|
||||
done = self.current_step >= self.ep_length
|
||||
|
|
@ -45,7 +45,9 @@ class CustomGymEnv(gym.Env):
|
|||
return np.zeros((4, 4, 3))
|
||||
|
||||
def seed(self, seed=None):
|
||||
pass
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
self.observation_space.seed(seed)
|
||||
|
||||
@staticmethod
|
||||
def custom_method(dim_0=1, dim_1=1):
|
||||
|
|
@ -440,3 +442,34 @@ def test_vec_env_is_wrapped():
|
|||
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2)
|
||||
assert vec_env.env_is_wrapped(Monitor) == [False, True]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
|
||||
def test_vec_seeding(vec_env_class):
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
# For SubprocVecEnv check for all starting methods
|
||||
start_methods = [None]
|
||||
if vec_env_class != DummyVecEnv:
|
||||
all_methods = {"forkserver", "spawn", "fork"}
|
||||
available_methods = multiprocessing.get_all_start_methods()
|
||||
start_methods = list(all_methods.intersection(available_methods))
|
||||
|
||||
for start_method in start_methods:
|
||||
if start_method is not None:
|
||||
vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)
|
||||
|
||||
n_envs = 3
|
||||
vec_env = vec_env_class([make_env] * n_envs)
|
||||
# Seed with no argument
|
||||
vec_env.seed()
|
||||
obs = vec_env.reset()
|
||||
_, rewards, _, _ = vec_env.step(np.array([vec_env.action_space.sample() for _ in range(n_envs)]))
|
||||
# Seed should be different per process
|
||||
assert not np.allclose(obs[0], obs[1])
|
||||
assert not np.allclose(rewards[0], rewards[1])
|
||||
assert not np.allclose(obs[1], obs[2])
|
||||
assert not np.allclose(rewards[1], rewards[2])
|
||||
|
||||
vec_env.close()
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def test_vec_monitor(tmp_path):
|
|||
|
||||
monitor_env.close()
|
||||
|
||||
with open(monitor_file, "rt") as file_handler:
|
||||
with open(monitor_file) as file_handler:
|
||||
first_line = file_handler.readline()
|
||||
assert first_line.startswith("#")
|
||||
metadata = json.loads(first_line[1:])
|
||||
|
|
@ -66,7 +66,7 @@ def test_vec_monitor_info_keywords(tmp_path):
|
|||
|
||||
monitor_env.close()
|
||||
|
||||
with open(monitor_file, "rt") as f:
|
||||
with open(monitor_file) as f:
|
||||
reader = csv.reader(f)
|
||||
for i, line in enumerate(reader):
|
||||
if i == 0 or i == 1:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class DummyDictEnv(gym.GoalEnv):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(DummyDictEnv, self).__init__()
|
||||
super().__init__()
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
|
||||
|
|
@ -388,11 +388,11 @@ def test_offpolicy_normalization(model_class, online_sampling):
|
|||
|
||||
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
|
||||
def test_sync_vec_normalize(make_env):
|
||||
env = DummyVecEnv([make_env])
|
||||
original_env = DummyVecEnv([make_env])
|
||||
|
||||
assert unwrap_vec_normalize(env) is None
|
||||
assert unwrap_vec_normalize(original_env) is None
|
||||
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
|
||||
env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
|
||||
|
||||
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
|
||||
|
||||
|
|
@ -433,6 +433,17 @@ def test_sync_vec_normalize(make_env):
|
|||
assert allclose(obs, eval_env.normalize_obs(original_obs))
|
||||
assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
|
||||
|
||||
# Check synchronization when only reward is normalized
|
||||
env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0)
|
||||
eval_env = DummyVecEnv([make_env])
|
||||
eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False)
|
||||
env.reset()
|
||||
env.step([env.action_space.sample()])
|
||||
assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
|
||||
sync_envs_normalization(env, eval_env)
|
||||
assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
|
||||
assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var)
|
||||
|
||||
|
||||
def test_discrete_obs():
|
||||
with pytest.raises(ValueError, match=".*only supports.*"):
|
||||
|
|
|
|||
Loading…
Reference in a new issue