diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b45ae31..4bc23e2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[extra,tests,docs] # Use headless version pip install opencv-python-headless diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 45ca8f5..20953d2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,6 +9,7 @@ pytest: - python --version # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error - MKL_THREADING_LAYER=GNU make pytest + coverage: '/^TOTAL.+?(\d+\%)$/' doc-build: script: diff --git a/README.md b/README.md index 2a0701c..973a180 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Documentation is available online: [https://stable-baselines3.readthedocs.io/](h ## Integrations -Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation. +Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation. 
## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents @@ -77,14 +77,14 @@ Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.h We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) -This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). +This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) ## Installation -**Note:** Stable-Baselines3 supports PyTorch >= 1.8.1. +**Note:** Stable-Baselines3 supports PyTorch >= 1.11 ### Prerequisites Stable Baselines3 requires Python 3.7+. @@ -122,7 +122,7 @@ from stable_baselines3 import PPO env = gym.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) -model.learn(total_timesteps=10000) +model.learn(total_timesteps=10_000) obs = env.reset() for i in range(1000): @@ -140,7 +140,7 @@ Or just train a model with a one liner if [the environment is registered in Gym] ```python from stable_baselines3 import PPO -model = PPO('MlpPolicy', 'CartPole-v1').learn(10000) +model = PPO("MlpPolicy", "CartPole-v1").learn(10_000) ``` Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples. 
@@ -172,6 +172,7 @@ All the following examples can be executed online using Google colab notebooks: | HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | | PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | QR-DQN[1](#f1) | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | +| RecurrentPPO[1](#f1) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | TQC[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | @@ -231,7 +232,7 @@ To cite this repository in publications: ## Maintainers -Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave) and [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli). +Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec). **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. 
diff --git a/docs/_static/img/split_graph.png b/docs/_static/img/split_graph.png new file mode 100644 index 0000000..c966c55 Binary files /dev/null and b/docs/_static/img/split_graph.png differ diff --git a/docs/conda_env.yml b/docs/conda_env.yml index a01d37b..98a5508 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -6,9 +6,9 @@ dependencies: - cpuonly=1.0=0 - pip=21.1 - python=3.7 - - pytorch=1.8.1=py3.7_cpu_0 + - pytorch=1.11=py3.7_cpu_0 - pip: - - gym>=0.17.2 + - gym==0.21 - cloudpickle - opencv-python-headless - pandas @@ -16,5 +16,5 @@ dependencies: - matplotlib - sphinx_autodoc_typehints - sphinx>=4.2 - # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 - sphinx_rtd_theme>=1.0 + - sphinx_copybutton diff --git a/docs/conf.py b/docs/conf.py index 088f8a0..b44be6f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # @@ -25,6 +24,14 @@ try: except ImportError: enable_spell_check = False +# Try to enable copy button +try: + import sphinx_copybutton # noqa: F401 + + enable_copy_button = True +except ImportError: + enable_copy_button = False + # source code directory, relative to this file, for sphinx-autobuild sys.path.insert(0, os.path.abspath("..")) @@ -46,13 +53,13 @@ sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # Read version from file version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() # -- Project information ----------------------------------------------------- project = "Stable Baselines3" -copyright = "2020, Stable Baselines3" +copyright = "2022, Stable Baselines3" author = "Stable Baselines3 Contributors" # The short X.Y version @@ -84,6 +91,9 @@ extensions = [ if enable_spell_check: extensions.append("sphinxcontrib.spelling") +if 
enable_copy_button: + extensions.append("sphinx_copybutton") + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -101,7 +111,7 @@ master_doc = "index" # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 474a047..55aba35 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -15,6 +15,7 @@ DQN ❌ ✔️ ❌ ❌ HER ✔️ ✔️ ❌ ❌ ❌ PPO ✔️ ✔️ ✔️ ✔️ ✔️ QR-DQN [#f1]_ ❌ ️ ✔️ ❌ ❌ ✔️ +RecurrentPPO [#f1]_ ✔️ ✔️ ✔️ ✔️ ✔️ SAC ✔️ ❌ ❌ ❌ ✔️ TD3 ✔️ ❌ ❌ ❌ ✔️ TQC [#f1]_ ✔️ ❌ ❌ ❌ ✔️ @@ -26,8 +27,8 @@ Maskable PPO [#f1]_ ❌ ✔️ ✔️ ✔ .. [#f1] Implemented in `SB3 Contrib `_ .. note:: - ``Tuple`` observation spaces are not supported by any environment - however single-level ``Dict`` spaces are (cf. :ref:`Examples `). + ``Tuple`` observation spaces are not supported by any environment, + however, single-level ``Dict`` spaces are (cf. :ref:`Examples `). Actions ``gym.spaces``: diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 355ecb2..2e2d1f7 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -61,7 +61,7 @@ Then you can define and train a RL agent with: model = A2C('CnnPolicy', env).learn(total_timesteps=1000) -To check that your environment follows the gym interface, please use: +To check that your environment follows the Gym interface that SB3 supports, please use: .. 
code-block:: python @@ -71,11 +71,11 @@ To check that your environment follows the gym interface, please use: # It will check your custom environment and output additional warnings if needed check_env(env) - +Gym also has its own `env checker `_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features). We have created a `colab notebook `_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface. -Alternatively, you may look at OpenAI Gym `built-in environments `_. However, the readers are cautioned as per OpenAI Gym `official wiki `_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them. +Alternatively, you may look at OpenAI Gym `built-in environments `_. However, the readers are cautioned as per OpenAI Gym `official wiki `_, it is advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them. Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env): diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a5b56b2..0d7e7c0 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -729,6 +729,16 @@ to keep track of the agent progress. model.learn(10_000) +SB3 with EnvPool or Isaac Gym +----------------------------- + +Just like Procgen (see above), `EnvPool `_ and `Isaac Gym `_ accelerate the environment by +already providing a vectorized implementation. + +To use SB3 with those tools, you must wrap the env with the tool-specific ``VecEnvWrapper`` that will pre-process the data for SB3, +you can find links to those wrappers in `issue #772 `_. 
+ + Record a Video -------------- diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 7beabb7..0169495 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -6,7 +6,7 @@ Installation Prerequisites ------------- -Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.8.1. +Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.11. Windows 10 ~~~~~~~~~~ @@ -54,6 +54,17 @@ Bleeding-edge version pip install git+https://github.com/DLR-RM/stable-baselines3 +.. note:: + + If you want to use the latest gym version (0.24+), you have to use + + .. code-block:: bash + + pip install git+https://github.com/carlosluis/stable-baselines3@fix_tests + + See `PR #780 `_ for more information. + + Development version ------------------- diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 9007ade..98cf84a 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -48,11 +48,14 @@ Hugging Face 🤗 The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾. You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3 +Most of them are available via the RL Zoo. 
Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3 We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T +For up to date instructions (for instance for using ``package_to_hub()``), please take a look at the Huggingface SB3 package README: https://github.com/huggingface/huggingface_sb3 + Installation ------------- @@ -137,3 +140,56 @@ Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a filename="ppo-CartPole-v1", commit_message="Added Cartpole-v1 model trained with PPO", ) + +MLFLow +====== + +If you want to use `MLFLow `_ to track your SB3 experiments, +you can adapt the following code which defines a custom logger output: + +.. code-block:: python + + import sys + from typing import Any, Dict, Tuple, Union + + import mlflow + import numpy as np + + from stable_baselines3 import SAC + from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger + + + class MLflowOutputFormat(KVWriter): + """ + Dumps key/value pairs into MLflow's numeric format. 
+ """ + + def write( + self, + key_values: Dict[str, Any], + key_excluded: Dict[str, Union[str, Tuple[str, ...]]], + step: int = 0, + ) -> None: + + for (key, value), (_, excluded) in zip( + sorted(key_values.items()), sorted(key_excluded.items()) + ): + + if excluded is not None and "mlflow" in excluded: + continue + + if isinstance(value, np.ScalarType): + if not isinstance(value, str): + mlflow.log_metric(key, value, step) + + + loggers = Logger( + folder=None, + output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()], + ) + + with mlflow.start_run(): + model = SAC("MlpPolicy", "Pendulum-v1", verbose=2) + # Set custom logger + model.set_logger(loggers) + model.learn(total_timesteps=10000, log_interval=1) diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst index 879a5fb..ef26870 100644 --- a/docs/guide/migration.rst +++ b/docs/guide/migration.rst @@ -141,7 +141,7 @@ DQN ^^^ Only the vanilla DQN is implemented right now but extensions will follow. -Default hyperparameters are taken from the nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. +Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. DDPG ^^^^ diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index 031f947..2f093f6 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -183,6 +183,16 @@ Some basic advice: - start with shaped reward (i.e. informative reward) and simplified version of your problem - debug with random actions to check that your environment works and follows the gym interface: +Two important things to keep in mind when creating a custom environment are to avoid breaking the Markov assumption +and to properly handle termination due to a timeout (maximum number of steps in an episode). +For instance, if there is some time delay between action and observation (e.g. 
due to wifi communication), you should give a history of observations +as input. + +Termination due to timeout (max number of steps per episode) needs to be handled separately. You should fill the key in the info dict: ``info["TimeLimit.truncated"] = True``. +If you are using the gym ``TimeLimit`` wrapper, this will be done automatically. +You can read `Time Limit in RL `_ or take a look at the `RL Tips and Tricks video `_ +for more details. + We provide a helper to check that your environment runs without error: @@ -241,12 +251,15 @@ We *recommend following those steps to have a working RL algorithm*: 1. Read the original paper several times 2. Read existing implementations (if available) 3. Try to have some "sign of life" on toy problems -4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo) - You usually need to run hyperparameter optimization for that step. +4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo). + You usually need to run hyperparameter optimization for that step. -You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf `issue #75 `_) +You need to be particularly careful with the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf. `issue #75 `_) and when to stop the gradient propagation. +Don't forget to handle termination due to timeout separately (see remark in the custom environment section above), +you can also take a look at `Issue #284 `_ and `Issue #633 `_. + A personal pick (by @araffin) for environments with gradual difficulty in RL with continuous actions: 1. 
Pendulum (easy to solve) diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst index 1dfa912..445832c 100644 --- a/docs/guide/sb3_contrib.rst +++ b/docs/guide/sb3_contrib.rst @@ -8,7 +8,7 @@ We implement experimental features in a separate contrib repository: `SB3-Contrib`_ This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still -providing the latest features, like Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or +providing the latest features, like RecurrentPPO (PPO LSTM), Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or Quantile Regression DQN (QR-DQN). Why create this repository? @@ -38,9 +38,11 @@ See documentation for the full list of included features. - `Augmented Random Search (ARS) `_ - `Quantile Regression DQN (QR-DQN)`_ +- `PPO with invalid action masking (Maskable PPO) `_ +- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) `_ - `Truncated Quantile Critics (TQC)`_ - `Trust Region Policy Optimization (TRPO) `_ -- `PPO with invalid action masking (Maskable PPO) `_ + **Gym Wrappers**: diff --git a/docs/guide/tensorboard.rst b/docs/guide/tensorboard.rst index dc39d8b..625c1be 100644 --- a/docs/guide/tensorboard.rst +++ b/docs/guide/tensorboard.rst @@ -26,10 +26,20 @@ You can also define custom logging name when training (by default it is the algo model.learn(total_timesteps=10_000, tb_log_name="first_run") # Pass reset_num_timesteps=False to continue the training curve in tensorboard # By default, it will create a new curve + # Keep tb_log_name constant to have continuous curve (see note below) model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False) model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False) +.. 
note:: + If you specify a different ``tb_log_name`` in subsequent runs, you will have split graphs, like in the figure below. + If you want them to be continuous, you must keep the same ``tb_log_name`` (see `issue #975 `_). + And, if you still manage to get your graphs split by other means, just put the tensorboard log files into the same folder. + + .. image:: ../_static/img/split_graph.png + :width: 330 + :alt: split_graph + Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command: .. code-block:: bash diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e17c3df..eca1173 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,8 +3,7 @@ Changelog ========== - -Release 1.5.1a0 (WIP) +Release 1.6.1a0 (WIP) --------------------------- Breaking Changes: @@ -18,16 +17,80 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ -- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) +- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) +- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers. 
+- Added multidimensional action space support (@qgallouedec) +- Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb) Deprecations: ^^^^^^^^^^^^^ Others: ^^^^^^^ +- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec) + +- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec) Documentation: ^^^^^^^^^^^^^^ +- Fixed typo in docstring "nature" -> "Nature" (@Melanol) +- Added info on split tensorboard logs (@Melanol) +- Fixed typo in ppo doc (@francescoluciano) +- Fixed typo in install doc (@jlp-ue) + + +Release 1.6.0 (2022-07-11) +--------------------------- + +**Recurrent PPO (PPO LSTM), better defaults for learning from pixels with SAC/TD3** + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former + ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) +- SB3 now requires PyTorch >= 1.11 +- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with SAC or DDPG/TD3, + ``share_features_extractor`` is now set to False by default and the ``net_arch=[256, 256]`` (instead of the previous ``net_arch=[]``) + +New Features: +^^^^^^^^^^^^^ + +SB3-Contrib +^^^^^^^^^^^ +- Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53 + + +Bug Fixes: +^^^^^^^^^^ +- Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) +- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) +- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) +- Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. 
None value was unchecked (@ScheiklP) +- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled +- Added a check for unbounded actions +- Fixed issues due to newer version of protobuf (tensorboard) and sphinx +- Fix exception causes all over the codebase (@cool-RR) +- Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede) +- Fixed a bug in ``kl_divergence`` check that would fail when using numpy arrays with MultiCategorical distribution + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Upgraded to Python 3.7+ syntax using ``pyupgrade`` +- Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG) + +Documentation: +^^^^^^^^^^^^^^ +- Added link to gym doc and gym env checker +- Fix typo in PPO doc (@bcollazo) +- Added link to PPO ICLR blog post +- Added remark about breaking Markov assumption and timeout handling +- Added doc about MLFlow integration via custom logger (@git-thor) +- Updated Huggingface integration doc +- Added copy button for code snippets +- Added doc about EnvPool and Isaac Gym support Release 1.5.0 (2022-03-25) @@ -923,7 +986,8 @@ Maintainers ----------- Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a), -`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_). +`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_), `Anssi Kanervisto`_ (aka `@Miffyli`_) +and `Quentin Gallouédec`_ (aka @qgallouedec). .. _Ashley Hill: https://github.com/hill-a .. _Antonin Raffin: https://araffin.github.io/ @@ -933,6 +997,8 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) .. _@AdamGleave: https://github.com/adamgleave .. _Anssi Kanervisto: https://github.com/Miffyli .. _@Miffyli: https://github.com/Miffyli +.. 
_Quentin Gallouédec: https://gallouedec.com/ +.. _@qgallouedec: https://github.com/qgallouedec @@ -957,4 +1023,5 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede +@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index 3aab653..d0c425f 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -8,14 +8,14 @@ PPO The `Proximal Policy Optimization `_ algorithm combines ideas from A2C (having multiple workers) and TRPO (it uses a trust region to improve the actor). -The main idea is that after an update, the new policy should be not too far form the old policy. +The main idea is that after an update, the new policy should be not too far from the old policy. For that, ppo uses clipping to avoid too large update. .. note:: PPO contains several modifications from the original algorithm not documented - by OpenAI: advantages are normalized and value function can be also clipped . + by OpenAI: advantages are normalized and value function can be also clipped. Notes @@ -25,11 +25,22 @@ Notes - Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8 - OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/ - Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html +- 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ Can I use? ---------- +.. 
note:: + + A recurrent version of PPO is available in our contrib repo: https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html + + However, we advise users to start with simple frame-stacking as a simpler, faster + and usually competitive alternative, more info in our report: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4 + See also `Procgen paper appendix Fig 11. `_. + In practice, you can stack multiple observations using ``VecFrameStack``. + + - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: diff --git a/setup.py b/setup.py index de615a7..2816316 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import os from setuptools import find_packages, setup -with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler: +with open(os.path.join("stable_baselines3", "version.txt")) as file_handler: __version__ = file_handler.read().strip() @@ -43,10 +43,10 @@ import gym from stable_baselines3 import PPO -env = gym.make('CartPole-v1') +env = gym.make("CartPole-v1") -model = PPO('MlpPolicy', env, verbose=1) -model.learn(total_timesteps=10000) +model = PPO("MlpPolicy", env, verbose=1) +model.learn(total_timesteps=10_000) obs = env.reset() for i in range(1000): @@ -57,12 +57,12 @@ for i in range(1000): obs = env.reset() ``` -Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): +Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): ```python from stable_baselines3 import PPO -model = PPO('MlpPolicy', 'CartPole-v1').learn(10000) +model = PPO("MlpPolicy", 
"CartPole-v1").learn(10_000) ``` """ # noqa:E501 @@ -75,7 +75,7 @@ setup( install_requires=[ "gym==0.21", # Fixed version due to breaking changes in 0.22 "numpy", - "torch>=1.8.1", + "torch>=1.11", # For saving models "cloudpickle", # For reading logs @@ -111,16 +111,21 @@ setup( "sphinxcontrib.spelling", # Type hints support "sphinx-autodoc-typehints", + # Copy button for code snippets + "sphinx_copybutton", ], "extra": [ # For render "opencv-python", # For atari games, - "ale-py~=0.7.4", + "ale-py==0.7.4", "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support "tensorboard>=2.2.0", + # Protobuf >= 4 has breaking changes + # which does not play well with tensorboard + "protobuf~=3.19.0", # Checking memory taken by replay buffer "psutil", ], diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 4e31c5b..d73f5f0 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -11,7 +11,7 @@ from stable_baselines3.td3 import TD3 # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 837ec42..13adf68 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -5,7 +5,7 @@ from gym import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance @@ -51,6 +51,12 @@ class A2C(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at 
the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], @@ -76,7 +82,7 @@ class A2C(OnPolicyAlgorithm): _init_setup_model: bool = True, ): - super(A2C, self).__init__( + super().__init__( policy, env, learning_rate=learning_rate, @@ -188,7 +194,7 @@ class A2C(OnPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> "A2C": - return super(A2C, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/a2c/policies.py b/stable_baselines3/a2c/policies.py index 79c85f8..7299b34 100644 --- a/stable_baselines3/a2c/policies.py +++ b/stable_baselines3/a2c/policies.py @@ -1,16 +1,7 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for A2C -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy - -register_policy("MlpPolicy", ActorCriticPolicy) -register_policy("CnnPolicy", ActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 832ad9f..a9b2eca 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -245,4 +245,4 @@ class AtariWrapper(gym.Wrapper): if clip_reward: env = ClipRewardEnv(env) - super(AtariWrapper, self).__init__(env) + super().__init__(env) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 
25c2638..9445ee4 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -17,7 +17,7 @@ from stable_baselines3.common.env_util import is_wrapped from stable_baselines3.common.logger import Logger from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.policies import BasePolicy, get_policy_from_name +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -60,7 +60,6 @@ class BaseAlgorithm(ABC): :param policy: Policy object :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) - :param policy_base: The base policy used by this method :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param policy_kwargs: Additional arguments to be passed to the policy on creation @@ -83,11 +82,13 @@ class BaseAlgorithm(ABC): :param supported_action_spaces: The action spaces supported by the algorithm. 
""" + # Policy aliases (see _get_policy_from_name()) + policy_aliases: Dict[str, Type[BasePolicy]] = {} + def __init__( self, policy: Type[BasePolicy], env: Union[GymEnv, str, None], - policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], policy_kwargs: Optional[Dict[str, Any]] = None, tensorboard_log: Optional[str] = None, @@ -101,9 +102,8 @@ class BaseAlgorithm(ABC): sde_sample_freq: int = -1, supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - - if isinstance(policy, str) and policy_base is not None: - self.policy_class = get_policy_from_name(policy_base, policy) + if isinstance(policy, str): + self.policy_class = self._get_policy_from_name(policy) else: self.policy_class = policy @@ -185,6 +185,11 @@ class BaseAlgorithm(ABC): if self.use_sde and not isinstance(self.action_space, gym.spaces.Box): raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.") + if isinstance(self.action_space, gym.spaces.Box): + assert np.all( + np.isfinite(np.array([self.action_space.low, self.action_space.high])) + ), "Continuous action space must have a finite lower and upper bound" + @staticmethod def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv: """ " @@ -209,11 +214,6 @@ class BaseAlgorithm(ABC): # Make sure that dict-spaces are not nested (not supported) check_for_nested_spaces(env.observation_space) - if isinstance(env.observation_space, gym.spaces.Dict): - for space in env.observation_space.spaces.values(): - if isinstance(space, gym.spaces.Dict): - raise ValueError("Nested observation spaces are not supported (Dict spaces inside Dict space).") - if not is_vecenv_wrapped(env, VecTransposeImage): wrap_with_vectranspose = False if isinstance(env.observation_space, gym.spaces.Dict): @@ -325,6 +325,23 @@ class BaseAlgorithm(ABC): "_custom_logger", ] + def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]: + """ + Get a policy class from 
its name representation. + + The goal here is to standardize policy naming, e.g. + all algorithms can call upon "MlpPolicy" or "CnnPolicy", + and they receive respective policies that work for them. + + :param policy_name: Alias of the policy + :return: A policy class (type) + """ + + if policy_name in self.policy_aliases: + return self.policy_aliases[policy_name] + else: + raise ValueError(f"Policy {policy_name} unknown") + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: """ Get the name of the torch variables that will be saved with @@ -375,6 +392,7 @@ class BaseAlgorithm(ABC): log_path=log_path, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, + verbose=self.verbose, ) callback = CallbackList([callback, eval_callback]) @@ -405,7 +423,7 @@ class BaseAlgorithm(ABC): :param tb_log_name: the name of the run for tensorboard log :return: """ - self.start_time = time.time() + self.start_time = time.time_ns() if self.ep_info_buffer is None or reset_num_timesteps: # Initialize buffers if they don't exist, or reinitialize if resetting counters @@ -611,11 +629,11 @@ class BaseAlgorithm(ABC): attr = None try: attr = recursive_getattr(self, name) - except Exception: + except Exception as e: # What errors recursive_getattr could throw? KeyError, but # possible something else too (e.g. if key is an int?). # Catch anything for now. - raise ValueError(f"Key {name} is an invalid object name.") + raise ValueError(f"Key {name} is an invalid object name.") from e if isinstance(attr, th.optim.Optimizer): # Optimizers do not support "strict" keyword... 
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index bba2272..5972531 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -13,6 +13,7 @@ from stable_baselines3.common.type_aliases import ( ReplayBufferSamples, RolloutBufferSamples, ) +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize try: @@ -39,10 +40,10 @@ class BaseBuffer(ABC): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, ): - super(BaseBuffer, self).__init__() + super().__init__() self.buffer_size = buffer_size self.observation_space = observation_space self.action_space = action_space @@ -51,7 +52,7 @@ class BaseBuffer(ABC): self.action_dim = get_action_dim(action_space) self.pos = 0 self.full = False - self.device = device + self.device = get_device(device) self.n_envs = n_envs @staticmethod @@ -157,13 +158,14 @@ class ReplayBuffer(BaseBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param n_envs: Number of parallel environments :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer which reduces by almost a factor two the memory used, at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 + Cannot be used in combination with handle_timeout_termination. :param handle_timeout_termination: Handle timeout termination (due to timelimit) separately and treat the task as infinite horizon task. 
https://github.com/DLR-RM/stable-baselines3/issues/284 @@ -174,12 +176,12 @@ class ReplayBuffer(BaseBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, ): - super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) # Adjust buffer size self.buffer_size = max(buffer_size // n_envs, 1) @@ -188,6 +190,13 @@ class ReplayBuffer(BaseBuffer): if psutil is not None: mem_available = psutil.virtual_memory().available + # there is a bug if both optimize_memory_usage and handle_timeout_termination are true + # see https://github.com/DLR-RM/stable-baselines3/issues/934 + if optimize_memory_usage and handle_timeout_termination: + raise ValueError( + "ReplayBuffer does not support optimize_memory_usage = True " + "and handle_timeout_termination = True simultaneously." 
+ ) self.optimize_memory_usage = optimize_memory_usage self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) @@ -239,8 +248,7 @@ class ReplayBuffer(BaseBuffer): next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape) # Same, for actions - if isinstance(self.action_space, spaces.Discrete): - action = action.reshape((self.n_envs, self.action_dim)) + action = action.reshape((self.n_envs, self.action_dim)) # Copy to avoid modification by reference self.observations[self.pos] = np.array(obs).copy() @@ -321,7 +329,7 @@ class RolloutBuffer(BaseBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor @@ -333,13 +341,13 @@ class RolloutBuffer(BaseBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, ): - super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) self.gae_lambda = gae_lambda self.gamma = gamma self.observations, self.actions, self.rewards, self.advantages = None, None, None, None @@ -358,7 +366,7 @@ class RolloutBuffer(BaseBuffer): self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.generator_ready = False - super(RolloutBuffer, self).reset() + super().reset() def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None: """ @@ -425,6 +433,9 @@ class 
RolloutBuffer(BaseBuffer): if isinstance(self.observation_space, spaces.Discrete): obs = obs.reshape((self.n_envs,) + self.obs_shape) + # Same reshape, for actions + action = action.reshape((self.n_envs, self.action_dim)) + self.observations[self.pos] = np.array(obs).copy() self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() @@ -483,7 +494,7 @@ class DictReplayBuffer(ReplayBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param n_envs: Number of parallel environments :param optimize_memory_usage: Enable a memory efficient variant Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702) @@ -497,7 +508,7 @@ class DictReplayBuffer(ReplayBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, @@ -578,8 +589,7 @@ class DictReplayBuffer(ReplayBuffer): self.next_observations[key][self.pos] = np.array(next_obs[key]).copy() # Same reshape, for actions - if isinstance(self.action_space, spaces.Discrete): - action = action.reshape((self.n_envs, self.action_dim)) + action = action.reshape((self.n_envs, self.action_dim)) self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() @@ -649,7 +659,7 @@ class DictRolloutBuffer(RolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to Monte-Carlo advantage estimate when set to 1. 
:param gamma: Discount factor @@ -661,7 +671,7 @@ class DictRolloutBuffer(RolloutBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 27ce5e6..e9f46fe 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -19,7 +19,7 @@ class BaseCallback(ABC): """ def __init__(self, verbose: int = 0): - super(BaseCallback, self).__init__() + super().__init__() # The RL model self.model = None # type: Optional[base_class.BaseAlgorithm] # An alias for self.model.get_env(), the environment used for training @@ -127,14 +127,14 @@ class EventCallback(BaseCallback): """ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): - super(EventCallback, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.callback = callback # Give access to the parent if callback is not None: self.callback.parent = self def init_callback(self, model: "base_class.BaseAlgorithm") -> None: - super(EventCallback, self).init_callback(model) + super().init_callback(model) if self.callback is not None: self.callback.init_callback(self.model) @@ -169,7 +169,7 @@ class CallbackList(BaseCallback): """ def __init__(self, callbacks: List[BaseCallback]): - super(CallbackList, self).__init__() + super().__init__() assert isinstance(callbacks, list) self.callbacks = callbacks @@ -228,7 +228,7 @@ class CheckpointCallback(BaseCallback): """ def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0): - super(CheckpointCallback, self).__init__(verbose) + super().__init__(verbose) self.save_freq = save_freq self.save_path = save_path self.name_prefix = name_prefix @@ -256,7 +256,7 @@ class ConvertCallback(BaseCallback): """ def 
__init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0): - super(ConvertCallback, self).__init__(verbose) + super().__init__(verbose) self.callback = callback def _on_step(self) -> bool: @@ -307,7 +307,7 @@ class EvalCallback(EventCallback): verbose: int = 1, warn: bool = True, ): - super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose) + super().__init__(callback_after_eval, verbose=verbose) self.callback_on_new_best = callback_on_new_best if self.callback_on_new_best is not None: @@ -380,12 +380,12 @@ class EvalCallback(EventCallback): if self.model.get_vec_normalize_env() is not None: try: sync_envs_normalization(self.training_env, self.eval_env) - except AttributeError: + except AttributeError as e: raise AssertionError( "Training and eval env are not wrapped the same way, " "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " "and warning above." - ) + ) from e # Reset success rate buffer self._is_success_buffer = [] @@ -480,7 +480,7 @@ class StopTrainingOnRewardThreshold(BaseCallback): """ def __init__(self, reward_threshold: float, verbose: int = 0): - super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.reward_threshold = reward_threshold def _on_step(self) -> bool: @@ -505,7 +505,7 @@ class EveryNTimesteps(EventCallback): """ def __init__(self, n_steps: int, callback: BaseCallback): - super(EveryNTimesteps, self).__init__(callback) + super().__init__(callback) self.n_steps = n_steps self.last_time_trigger = 0 @@ -528,7 +528,7 @@ class StopTrainingOnMaxEpisodes(BaseCallback): """ def __init__(self, max_episodes: int, verbose: int = 0): - super(StopTrainingOnMaxEpisodes, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.max_episodes = max_episodes self._total_max_episodes = max_episodes self.n_episodes = 0 @@ -573,7 +573,7 @@ class StopTrainingOnNoModelImprovement(BaseCallback): """ def 
__init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): - super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.max_no_improvement_evals = max_no_improvement_evals self.min_evals = min_evals self.last_best_mean_reward = -np.inf diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 1c0e54a..fc48625 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, Union import gym +import numpy as np import torch as th from gym import spaces from torch import nn @@ -16,7 +17,7 @@ class Distribution(ABC): """Abstract base class for distributions.""" def __init__(self): - super(Distribution, self).__init__() + super().__init__() self.distribution = None @abstractmethod @@ -120,7 +121,7 @@ class DiagGaussianDistribution(Distribution): """ def __init__(self, action_dim: int): - super(DiagGaussianDistribution, self).__init__() + super().__init__() self.action_dim = action_dim self.mean_actions = None self.log_std = None @@ -201,13 +202,13 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): """ def __init__(self, action_dim: int, epsilon: float = 1e-6): - super(SquashedDiagGaussianDistribution, self).__init__(action_dim) + super().__init__(action_dim) # Avoid NaN (prevents division by zero or log of zero) self.epsilon = epsilon self.gaussian_actions = None def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution": - super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) + super().proba_distribution(mean_actions, log_std) return self def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: @@ -219,7 +220,7 @@ class 
SquashedDiagGaussianDistribution(DiagGaussianDistribution): gaussian_actions = TanhBijector.inverse(actions) # Log likelihood for a Gaussian distribution - log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions) + log_prob = super().log_prob(gaussian_actions) # Squash correction (from original SAC implementation) # this comes from the fact that tanh is bijective and differentiable log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1) @@ -254,7 +255,7 @@ class CategoricalDistribution(Distribution): """ def __init__(self, action_dim: int): - super(CategoricalDistribution, self).__init__() + super().__init__() self.action_dim = action_dim def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -305,7 +306,7 @@ class MultiCategoricalDistribution(Distribution): """ def __init__(self, action_dims: List[int]): - super(MultiCategoricalDistribution, self).__init__() + super().__init__() self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -360,7 +361,7 @@ class BernoulliDistribution(Distribution): """ def __init__(self, action_dims: int): - super(BernoulliDistribution, self).__init__() + super().__init__() self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -433,7 +434,7 @@ class StateDependentNoiseDistribution(Distribution): learn_features: bool = False, epsilon: float = 1e-6, ): - super(StateDependentNoiseDistribution, self).__init__() + super().__init__() self.action_dim = action_dim self.latent_sde_dim = None self.mean_actions = None @@ -577,10 +578,10 @@ class StateDependentNoiseDistribution(Distribution): return th.mm(latent_sde, self.exploration_mat) # Use batch matrix multiplication for efficient computation # (batch_size, n_features) -> (batch_size, 1, n_features) - latent_sde = latent_sde.unsqueeze(1) + latent_sde = latent_sde.unsqueeze(dim=1) # (batch_size, 1, n_actions) noise = th.bmm(latent_sde, self.exploration_matrices) 
- return noise.squeeze(1) + return noise.squeeze(dim=1) def actions_from_params( self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False @@ -597,7 +598,7 @@ class StateDependentNoiseDistribution(Distribution): return actions, log_prob -class TanhBijector(object): +class TanhBijector: """ Bijective transformation of a probability distribution using a squashing function (tanh) @@ -607,7 +608,7 @@ class TanhBijector(object): """ def __init__(self, epsilon: float = 1e-6): - super(TanhBijector, self).__init__() + super().__init__() self.epsilon = epsilon @staticmethod @@ -657,7 +658,6 @@ def make_proba_distribution( dist_kwargs = {} if isinstance(action_space, spaces.Box): - assert len(action_space.shape) == 1, "Error: the action space must be a vector" cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution return cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): @@ -688,7 +688,7 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor # MultiCategoricalDistribution is not a PyTorch Distribution subclass # so we need to implement it ourselves! 
if isinstance(dist_pred, MultiCategoricalDistribution): - assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space" + assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space" return th.stack( [th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)], dim=1, diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index c4e5669..3b2c502 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -147,7 +147,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action try: _check_obs(obs[key], observation_space.spaces[key], "reset") except AssertionError as e: - raise AssertionError(f"Error while checking key={key}: " + str(e)) + raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "reset") @@ -166,7 +166,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action try: _check_obs(obs[key], observation_space.spaces[key], "step") except AssertionError as e: - raise AssertionError(f"Error while checking key={key}: " + str(e)) + raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "step") @@ -274,6 +274,11 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - "cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" ) + if isinstance(action_space, spaces.Box): + assert np.all( + np.isfinite(np.array([action_space.low, action_space.high])) + ), "Continuous action space must have a finite lower and upper bound" + if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32): warnings.warn( f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors." 
diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index c5d713a..a881b32 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -36,7 +36,7 @@ class BitFlippingEnv(GoalEnv): image_obs_space: bool = False, channel_first: bool = True, ): - super(BitFlippingEnv, self).__init__() + super().__init__() # Shape of the observation when using image space self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) # The achieved goal is determined by the current state @@ -115,7 +115,7 @@ class BitFlippingEnv(GoalEnv): if self.discrete_obs_space: # The internal state is the binary representation of the # observed one - return int(sum([state[i] * 2**i for i in range(len(state))])) + return int(sum(state[i] * 2**i for i in range(len(state)))) if self.image_obs_space: size = np.prod(self.image_shape) @@ -135,7 +135,7 @@ class BitFlippingEnv(GoalEnv): if isinstance(state, int): state = np.array(state).reshape(batch_size, -1) # Convert to binary representation - state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int) + state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int) elif self.image_obs_space: state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 else: diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 177a641..2e5f13f 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -42,7 +42,7 @@ class SimpleMultiObsEnv(gym.Env): discrete_actions: bool = True, channel_last: bool = True, ): - super(SimpleMultiObsEnv, self).__init__() + super().__init__() self.vector_size = 5 if channel_last: diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 6493a3e..1295e5b 100644 --- a/stable_baselines3/common/logger.py +++ 
b/stable_baselines3/common/logger.py @@ -17,6 +17,7 @@ try: except ImportError: SummaryWriter = None + DEBUG = 10 INFO = 20 WARN = 30 @@ -24,7 +25,7 @@ ERROR = 40 DISABLED = 50 -class Video(object): +class Video: """ Video data class storing the video frames and the frame per seconds @@ -37,7 +38,7 @@ class Video(object): self.fps = fps -class Figure(object): +class Figure: """ Figure data class storing a matplotlib figure and whether to close the figure after logging it @@ -50,7 +51,7 @@ class Figure(object): self.close = close -class Image(object): +class Image: """ Image data class storing an image and data format @@ -80,13 +81,13 @@ class FormatUnsupportedError(NotImplementedError): format_str = f"formats {', '.join(unsupported_formats)} are" else: format_str = f"format {unsupported_formats[0]} is" - super(FormatUnsupportedError, self).__init__( + super().__init__( f"The {format_str} not supported for the {value_description} value logged.\n" f"You can exclude formats via the `exclude` parameter of the logger's `record` function." 
) -class KVWriter(object): +class KVWriter: """ Key Value writer """ @@ -108,7 +109,7 @@ class KVWriter(object): raise NotImplementedError -class SeqWriter(object): +class SeqWriter: """ sequence writer """ @@ -246,12 +247,13 @@ def filter_excluded_keys( class JSONOutputFormat(KVWriter): - def __init__(self, filename: str): - """ - log to a file, in the JSON format + """ + Log to a file, in the JSON format - :param filename: the file to write the log to - """ + :param filename: the file to write the log to + """ + + def __init__(self, filename: str): self.file = open(filename, "wt") def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: @@ -287,13 +289,13 @@ class JSONOutputFormat(KVWriter): class CSVOutputFormat(KVWriter): + """ + Log to a file, in a CSV format + + :param filename: the file to write the log to + """ + def __init__(self, filename: str): - """ - log to a file, in a CSV format - - :param filename: the file to write the log to - """ - self.file = open(filename, "w+t") self.keys = [] self.separator = "," @@ -351,12 +353,13 @@ class CSVOutputFormat(KVWriter): class TensorBoardOutputFormat(KVWriter): - def __init__(self, folder: str): - """ - Dumps key/value pairs into TensorBoard's numeric format. + """ + Dumps key/value pairs into TensorBoard's numeric format. - :param folder: the folder to write the log to - """ + :param folder: the folder to write the log to + """ + + def __init__(self, folder: str): assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so" self.writer = SummaryWriter(log_dir=folder) @@ -427,7 +430,7 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWr # ================================================================ -class Logger(object): +class Logger: """ The logger class. 
@@ -623,7 +626,7 @@ def read_json(filename: str) -> pandas.DataFrame: :return: the data in the json """ data = [] - with open(filename, "rt") as file_handler: + with open(filename) as file_handler: for line in file_handler: data.append(json.loads(line)) return pandas.DataFrame(data) diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 04cda22..a482b72 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -36,7 +36,7 @@ class Monitor(gym.Wrapper): reset_keywords: Tuple[str, ...] = (), info_keywords: Tuple[str, ...] = (), ): - super(Monitor, self).__init__(env=env) + super().__init__(env=env) self.t_start = time.time() if filename is not None: self.results_writer = ResultsWriter( @@ -110,7 +110,7 @@ class Monitor(gym.Wrapper): """ Closes the environment """ - super(Monitor, self).close() + super().close() if self.results_writer is not None: self.results_writer.close() @@ -224,7 +224,7 @@ def load_results(path: str) -> pandas.DataFrame: raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}") data_frames, headers = [], [] for file_name in monitor_files: - with open(file_name, "rt") as file_handler: + with open(file_name) as file_handler: first_line = file_handler.readline() assert first_line[0] == "#" header = json.loads(first_line[1:]) diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index b1db6f4..baa72e9 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -11,7 +11,7 @@ class ActionNoise(ABC): """ def __init__(self): - super(ActionNoise, self).__init__() + super().__init__() def reset(self) -> None: """ @@ -35,7 +35,7 @@ class NormalActionNoise(ActionNoise): def __init__(self, mean: np.ndarray, sigma: np.ndarray): self._mu = mean self._sigma = sigma - super(NormalActionNoise, self).__init__() + super().__init__() def __call__(self) -> np.ndarray: return np.random.normal(self._mu, 
self._sigma) @@ -72,7 +72,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise): self.initial_noise = initial_noise self.noise_prev = np.zeros_like(self._mu) self.reset() - super(OrnsteinUhlenbeckActionNoise, self).__init__() + super().__init__() def __call__(self) -> np.ndarray: noise = ( @@ -105,8 +105,8 @@ class VectorizedActionNoise(ActionNoise): try: self.n_envs = int(n_envs) assert self.n_envs > 0 - except (TypeError, AssertionError): - raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") + except (TypeError, AssertionError) as e: + raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e self.base_noise = base_noise self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)] diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 27e8bdd..b841eb0 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -1,5 +1,6 @@ import io import pathlib +import sys import time import warnings from copy import deepcopy @@ -28,7 +29,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): :param policy: Policy object :param env: The environment to learn from (if registered in Gym, can be str. 
Can be None for loading trained models) - :param policy_base: The base policy used by this method :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer @@ -76,7 +76,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): self, policy: Type[BasePolicy], env: Union[GymEnv, str], - policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, @@ -104,10 +103,9 @@ class OffPolicyAlgorithm(BaseAlgorithm): supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - super(OffPolicyAlgorithm, self).__init__( + super().__init__( policy=policy, env=env, - policy_base=policy_base, learning_rate=learning_rate, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, @@ -160,8 +158,10 @@ class OffPolicyAlgorithm(BaseAlgorithm): try: train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) - except ValueError: - raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!") + except ValueError as e: + raise ValueError( + f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!" + ) from e if not isinstance(train_freq[0], int): raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}") @@ -428,8 +428,8 @@ class OffPolicyAlgorithm(BaseAlgorithm): """ Write log. 
""" - time_elapsed = time.time() - self.start_time - fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8)) + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) self.logger.record("time/episodes", self._episode_num, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 48cb365..84c89d9 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,3 +1,4 @@ +import sys import time from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -8,7 +9,7 @@ import torch as th from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy +from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -34,7 +35,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param policy_base: The base policy used by this method :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. 
(Only available when passing string for the environment) @@ -62,7 +62,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): max_grad_norm: float, use_sde: bool, sde_sample_freq: int, - policy_base: Type[BasePolicy] = ActorCriticPolicy, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, monitor_wrapper: bool = True, @@ -74,10 +73,9 @@ class OnPolicyAlgorithm(BaseAlgorithm): supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - super(OnPolicyAlgorithm, self).__init__( + super().__init__( policy=policy, env=env, - policy_base=policy_base, learning_rate=learning_rate, policy_kwargs=policy_kwargs, verbose=verbose, @@ -257,13 +255,14 @@ class OnPolicyAlgorithm(BaseAlgorithm): # Display training infos if log_interval is not None and iteration % log_interval == 0: - fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time)) + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) self.logger.record("time/iterations", iteration, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) self.logger.record("time/fps", fps) - self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") self.logger.dump(step=self.num_timesteps) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index fe36b2e..d122acd 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -67,7 +67,7 @@ class 
BaseModel(nn.Module, ABC): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(BaseModel, self).__init__() + super().__init__() if optimizer_kwargs is None: optimizer_kwargs = {} @@ -267,7 +267,7 @@ class BasePolicy(BaseModel): """ def __init__(self, *args, squash_output: bool = False, **kwargs): - super(BasePolicy, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._squash_output = squash_output @staticmethod @@ -336,8 +336,8 @@ class BasePolicy(BaseModel): with th.no_grad(): actions = self._predict(observation, deterministic=deterministic) - # Convert to numpy - actions = actions.cpu().numpy() + # Convert to numpy, and reshape to the original action shape + actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape) if isinstance(self.action_space, gym.spaces.Box): if self.squash_output: @@ -350,7 +350,7 @@ class BasePolicy(BaseModel): # Remove batch dimension if needed if not vectorized_env: - actions = actions[0] + actions = actions.squeeze(axis=0) return actions, state @@ -437,7 +437,7 @@ class ActorCriticPolicy(BasePolicy): if optimizer_class == th.optim.Adam: optimizer_kwargs["eps"] = 1e-5 - super(ActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -592,6 +592,7 @@ class ActorCriticPolicy(BasePolicy): distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) + actions = actions.reshape((-1,) + self.action_space.shape) return actions, values, log_prob def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution: @@ -724,7 +725,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(ActorCriticCnnPolicy, self).__init__( + super().__init__( observation_space, 
action_space, lr_schedule, @@ -799,7 +800,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MultiInputActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -895,68 +896,3 @@ class ContinuousCritic(BaseModel): with th.no_grad(): features = self.extract_features(obs) return self.q_networks[0](th.cat([features, actions], dim=1)) - - -_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] - - -def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: - """ - Returns the registered policy from the base type and name. - See `register_policy` for registering policies and explanation. - - :param base_policy_type: the base policy class - :param name: the policy name - :return: the policy - """ - if base_policy_type not in _policy_registry: - raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") - if name not in _policy_registry[base_policy_type]: - raise KeyError( - f"Error: unknown policy type {name}," - f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!" - ) - return _policy_registry[base_policy_type][name] - - -def register_policy(name: str, policy: Type[BasePolicy]) -> None: - """ - Register a policy, so it can be called using its name. - e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...). - - The goal here is to standardize policy naming, e.g. - all algorithms can call upon "MlpPolicy" or "CnnPolicy", - and they receive respective policies that work for them. - Consider following: - - OnlinePolicy - -- OnlineMlpPolicy ("MlpPolicy") - -- OnlineCnnPolicy ("CnnPolicy") - OfflinePolicy - -- OfflineMlpPolicy ("MlpPolicy") - -- OfflineCnnPolicy ("CnnPolicy") - - Two policies have name "MlpPolicy" and two have "CnnPolicy". 
- In `get_policy_from_name`, the parent class (e.g. OnlinePolicy) - is given and used to select and return the correct policy. - - :param name: the policy name - :param policy: the policy class - """ - sub_class = None - for cls in BasePolicy.__subclasses__(): - if issubclass(policy, cls): - sub_class = cls - break - if sub_class is None: - raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") - - if sub_class not in _policy_registry: - _policy_registry[sub_class] = {} - if name in _policy_registry[sub_class]: - # Check if the registered policy is same - # we try to register. If not so, - # do not override and complain. - if _policy_registry[sub_class][name] != policy: - raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.") - _policy_registry[sub_class][name] = policy diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index fb3ae8b..b48f922 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -3,7 +3,7 @@ from typing import Tuple, Union import numpy as np -class RunningMeanStd(object): +class RunningMeanStd: def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] 
= ()): """ Calulates the running mean and std of a data stream diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index e0b104f..1569001 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -206,8 +206,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb mode = mode.lower() try: mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode] - except KeyError: - raise ValueError("Expected mode to be either 'w' or 'r'.") + except KeyError as e: + raise ValueError("Expected mode to be either 'w' or 'r'.") from e if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable(): e1 = "writable" if "w" == mode else "readable" raise ValueError(f"Expected a {e1} file.") @@ -441,7 +441,7 @@ def load_from_zip_file( # State dicts. Store into params dictionary # with same name as in .zip file (without .pth) params[os.path.splitext(file_path)[0]] = th_object - except zipfile.BadZipFile: + except zipfile.BadZipFile as e: # load_path wasn't a zip file - raise ValueError(f"Error: the file {load_path} wasn't a zip-file") + raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e return data, params, pytorch_variables diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py index ba70a5f..377b7f6 100644 --- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py @@ -54,21 +54,21 @@ class RMSpropTFLike(Optimizer): centered: bool = False, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= momentum: - raise ValueError("Invalid momentum value: {}".format(momentum)) + raise 
ValueError(f"Invalid momentum value: {momentum}") if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= alpha: - raise ValueError("Invalid alpha value: {}".format(alpha)) + raise ValueError(f"Invalid alpha value: {alpha}") defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) - super(RMSpropTFLike, self).__init__(params, defaults) + super().__init__(params, defaults) def __setstate__(self, state: Dict[str, Any]) -> None: - super(RMSpropTFLike, self).__setstate__(state) + super().__setstate__(state) for group in self.param_groups: group.setdefault("momentum", 0) group.setdefault("centered", False) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 589d12e..f87337c 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -19,7 +19,7 @@ class BaseFeaturesExtractor(nn.Module): """ def __init__(self, observation_space: gym.Space, features_dim: int = 0): - super(BaseFeaturesExtractor, self).__init__() + super().__init__() assert features_dim > 0 self._observation_space = observation_space self._features_dim = features_dim @@ -41,7 +41,7 @@ class FlattenExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) + super().__init__(observation_space, get_flattened_obs_dim(observation_space)) self.flatten = nn.Flatten() def forward(self, observations: th.Tensor) -> th.Tensor: @@ -50,7 +50,7 @@ class FlattenExtractor(BaseFeaturesExtractor): class NatureCNN(BaseFeaturesExtractor): """ - CNN from DQN nature paper: + CNN from DQN Nature paper: Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. 
@@ -61,7 +61,7 @@ class NatureCNN(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): - super(NatureCNN, self).__init__(observation_space, features_dim) + super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper assert is_image_space(observation_space, check_channels=False), ( @@ -169,7 +169,7 @@ class MlpExtractor(nn.Module): activation_fn: Type[nn.Module], device: Union[th.device, str] = "auto", ): - super(MlpExtractor, self).__init__() + super().__init__() device = get_device(device) shared_net, policy_net, value_net = [], [], [] policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network @@ -250,7 +250,7 @@ class CombinedExtractor(BaseFeaturesExtractor): def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256): # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! 
- super(CombinedExtractor, self).__init__(observation_space, features_dim=1) + super().__init__(observation_space, features_dim=1) extractors = {} diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 7e69d39..f4c29ab 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -50,7 +50,7 @@ class ReplayBufferSamples(NamedTuple): class DictReplayBufferSamples(ReplayBufferSamples): observations: TensorDict actions: th.Tensor - next_observations: th.Tensor + next_observations: TensorDict dones: th.Tensor rewards: th.Tensor diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 8504c8d..94cd658 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -154,15 +154,18 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: return device -def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int: +def get_latest_run_id(log_path: str = "", log_name: str = "") -> int: """ Returns the latest run number for the given log name and log path, by finding the greatest number in the directories. + :param log_path: Path to the log folder containing several runs. + :param log_name: Name of the experiment. Each run is stored + in a folder named ``log_name_1``, ``log_name_2``, ... 
:return: latest run number """ max_run_id = 0 - for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"): + for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")): file_name = path.split(os.sep)[-1] ext = file_name.split("_")[-1] if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 37ebc36..3880fbd 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: env_tmp, eval_env_tmp = env, eval_env while isinstance(env_tmp, VecEnvWrapper): if isinstance(env_tmp, VecNormalize): - eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) + # Only synchronize if observation normalization exists + if hasattr(env_tmp, "obs_rms"): + eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) env_tmp = env_tmp.venv eval_env_tmp = eval_env_tmp.venv diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index d3e624a..9870605 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -305,7 +305,7 @@ class VecEnvWrapper(VecEnv): own_class = f"{type(self).__module__}.{type(self).__name__}" error_str = ( f"Error: Recursive attribute lookup for {name} from {own_class} is " - "ambiguous and hides attribute from {blocked_class}" + f"ambiguous and hides attribute from {blocked_class}" ) raise AttributeError(error_str) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 5eb87cd..c0efc8c 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -51,7 +51,9 @@ class DummyVecEnv(VecEnv): return 
(self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: - seeds = list() + if seed is None: + seed = np.random.randint(0, 2**32 - 1) + seeds = [] for idx, env in enumerate(self.envs): seeds.append(env.seed(seed + idx)) return seeds diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index affd775..733b728 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -7,7 +7,7 @@ from gym import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first -class StackedObservations(object): +class StackedObservations: """ Frame stacking wrapper for data. diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 1050f3e..f723c71 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -123,6 +123,8 @@ class SubprocVecEnv(VecEnv): return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + if seed is None: + seed = np.random.randint(0, 2**32 - 1) for idx, remote in enumerate(self.remotes): remote.send(("seed", seed + idx)) return [remote.recv() for remote in self.remotes] @@ -215,6 +217,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space elif isinstance(space, gym.spaces.Tuple): assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len))) + return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) else: return np.stack(obs) diff --git 
a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 859f1ec..ca590cb 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -37,7 +37,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> return obs_dict elif isinstance(obs_space, gym.spaces.Tuple): assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" - return tuple((obs_dict[i] for i in range(len(obs_space.spaces)))) + return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) else: assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" return obs_dict[None] diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index e6f728b..b6b0ad8 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -26,7 +26,7 @@ class VecTransposeImage(VecEnvWrapper): self.skip = skip # Do nothing if skip: - super(VecTransposeImage, self).__init__(venv) + super().__init__(venv) return if isinstance(venv.observation_space, spaces.dict.Dict): @@ -39,7 +39,7 @@ class VecTransposeImage(VecEnvWrapper): observation_space.spaces[key] = self.transpose_space(space, key) else: observation_space = self.transpose_space(venv.observation_space) - super(VecTransposeImage, self).__init__(venv, observation_space=observation_space) + super().__init__(venv, observation_space=observation_space) @staticmethod def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 14293ca..53d3fb6 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -78,7 +78,7 @@ class DDPG(TD3): _init_setup_model: bool = True, ): - super(DDPG, self).__init__( + super().__init__( policy=policy, env=env, 
learning_rate=learning_rate, @@ -127,7 +127,7 @@ class DDPG(TD3): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(DDPG, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index a7aec6b..0cd6dfb 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -8,10 +8,11 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update -from stable_baselines3.dqn.policies import DQNPolicy +from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy class DQN(OffPolicyAlgorithm): @@ -19,7 +20,7 @@ class DQN(OffPolicyAlgorithm): Deep Q-Network (DQN) Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236 - Default hyperparameters are taken from the nature paper, + Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 
@@ -59,6 +60,12 @@ class DQN(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[DQNPolicy]], @@ -88,10 +95,9 @@ class DQN(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(DQN, self).__init__( + super().__init__( policy, env, - DQNPolicy, learning_rate, buffer_size, learning_starts, @@ -132,7 +138,7 @@ class DQN(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(DQN, self)._setup_model() + super()._setup_model() self._create_aliases() self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, @@ -255,7 +261,7 @@ class DQN(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(DQN, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -268,7 +274,7 @@ class DQN(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"] + return super()._excluded_save_params() + ["q_net", "q_net_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "policy.optimizer"] diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 099a4e3..ed3497c 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from torch import nn -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, CombinedExtractor, @@ -37,7 +37,7 @@ class QNetwork(BasePolicy): activation_fn: Type[nn.Module] = nn.ReLU, 
normalize_images: bool = True, ): - super(QNetwork, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -118,7 +118,7 @@ class DQNPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(DQNPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -239,7 +239,7 @@ class CnnPolicy(DQNPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -284,7 +284,7 @@ class MultiInputPolicy(DQNPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -296,8 +296,3 @@ class MultiInputPolicy(DQNPolicy): optimizer_class, optimizer_kwargs, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 9a41477..e3fc63e 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in if current_max_episode_length is None: raise AttributeError # if not available check if a valid value was passed as an argument - except AttributeError: + except AttributeError as e: raise ValueError( "The max episode length could not be inferred.\n" "You must specify a `max_episode_steps` when registering the environment,\n" "use a `gym.wrappers.TimeLimit` wrapper " "or pass `max_episode_length` to the model constructor" - ) + ) from e return 
current_max_episode_length @@ -73,7 +73,7 @@ class HerReplayBuffer(DictReplayBuffer): self, env: VecEnv, buffer_size: int, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", replay_buffer: Optional[DictReplayBuffer] = None, max_episode_length: Optional[int] = None, n_sampled_goal: int = 4, @@ -82,7 +82,7 @@ class HerReplayBuffer(DictReplayBuffer): handle_timeout_termination: bool = True, ): - super(HerReplayBuffer, self).__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs) + super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs) # convert goal_selection_strategy into GoalSelectionStrategy if string if isinstance(goal_selection_strategy, str): @@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer): elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: # replay with random state which comes from the same episode and was observed after current transition transitions_indices = np.random.randint( - transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices] + transitions_indices[her_indices], self.episode_lengths[her_episode_indices] ) elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: @@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer): else: raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!") - return self._buffer["achieved_goal"][her_episode_indices, transitions_indices] + return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices] def _sample_transitions( self, @@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer): ep_lengths = self.episode_lengths[episode_indices] - # Special case when using the "future" goal sampling strategy - # we cannot sample all transitions, we have to remove the last timestep - if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: - # restrict the sampling domain when ep_lengths > 1 - # otherwise 
filter out the indices - her_indices = her_indices[ep_lengths[her_indices] > 1] - ep_lengths[her_indices] -= 1 - if online_sampling: # Select which transitions to use transitions_indices = np.random.randint(ep_lengths) diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 7427cfc..fb7afae 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -1,16 +1,7 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for PPO -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy - -register_policy("MlpPolicy", ActorCriticPolicy) -register_policy("CnnPolicy", ActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 088bab3..5b8d9e2 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -7,7 +7,7 @@ from gym import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn @@ -19,7 +19,7 @@ class PPO(OnPolicyAlgorithm): Paper: https://arxiv.org/abs/1707.06347 Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and - and Stable Baselines (PPO2 from 
https://github.com/hill-a/stable-baselines) + Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines) Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html @@ -65,6 +65,12 @@ class PPO(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], @@ -93,7 +99,7 @@ class PPO(OnPolicyAlgorithm): _init_setup_model: bool = True, ): - super(PPO, self).__init__( + super().__init__( policy, env, learning_rate=learning_rate, @@ -156,7 +162,7 @@ class PPO(OnPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(PPO, self)._setup_model() + super()._setup_model() # Initialize schedules for policy/value clipping self.clip_range = get_schedule_fn(self.clip_range) @@ -301,7 +307,7 @@ class PPO(OnPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> "PPO": - return super(PPO, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 0bd1382..255bd75 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -6,7 +6,7 @@ import torch as th from torch import nn from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution -from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -65,7 +65,7 @@ class Actor(BasePolicy): clip_mean: float = 2.0, 
normalize_images: bool = True, ): - super(Actor, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -235,9 +235,9 @@ class SACPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): - super(SACPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -248,10 +248,7 @@ class SACPolicy(BasePolicy): ) if net_arch is None: - if features_extractor_class == NatureCNN: - net_arch = [] - else: - net_arch = [256, 256] + net_arch = [256, 256] actor_arch, critic_arch = get_actor_critic_arch(net_arch) @@ -422,9 +419,9 @@ class CnnPolicy(SACPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -493,9 +490,9 @@ class MultiInputPolicy(SACPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -514,8 +511,3 @@ class MultiInputPolicy(SACPolicy): n_critics, share_features_extractor, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 1a1ae1f..b7cbcf6 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -8,9 +8,10 @@ from torch.nn import functional as F from 
stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update -from stable_baselines3.sac.policies import SACPolicy +from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy class SAC(OffPolicyAlgorithm): @@ -72,6 +73,12 @@ class SAC(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[SACPolicy]], @@ -103,10 +110,9 @@ class SAC(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(SAC, self).__init__( + super().__init__( policy, env, - SACPolicy, learning_rate, buffer_size, learning_starts, @@ -144,7 +150,7 @@ class SAC(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(SAC, self)._setup_model() + super()._setup_model() self._create_aliases() # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": @@ -248,7 +254,7 @@ class SAC(OffPolicyAlgorithm): current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) + critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) critic_losses.append(critic_loss.item()) # Optimize the critic @@ -297,7 +303,7 @@ class SAC(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(SAC, self).learn( + return super().learn( 
total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -310,7 +316,7 @@ class SAC(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 264c760..8781b32 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from torch import nn -from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -42,7 +42,7 @@ class Actor(BasePolicy): activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, ): - super(Actor, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -119,9 +119,9 @@ class TD3Policy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): - super(TD3Policy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -134,7 +134,7 @@ class TD3Policy(BasePolicy): # Default network architecture, from the original paper if net_arch is None: if features_extractor_class == NatureCNN: - net_arch = [] + net_arch = [256, 256] else: net_arch = [400, 300] @@ -281,9 +281,9 @@ class CnnPolicy(TD3Policy): optimizer_class: Type[th.optim.Optimizer] = 
th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -335,9 +335,9 @@ class MultiInputPolicy(TD3Policy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -351,8 +351,3 @@ class MultiInputPolicy(TD3Policy): n_critics, share_features_extractor, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index eb257a6..34a783d 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -8,9 +8,10 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update -from stable_baselines3.td3.policies import TD3Policy +from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy class TD3(OffPolicyAlgorithm): @@ -60,6 +61,12 @@ class TD3(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, 
policy: Union[str, Type[TD3Policy]], @@ -88,10 +95,9 @@ class TD3(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(TD3, self).__init__( + super().__init__( policy, env, - TD3Policy, learning_rate, buffer_size, learning_starts, @@ -123,7 +129,7 @@ class TD3(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(TD3, self)._setup_model() + super()._setup_model() self._create_aliases() def _create_aliases(self) -> None: @@ -162,7 +168,7 @@ class TD3(OffPolicyAlgorithm): current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) + critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) critic_losses.append(critic_loss.item()) # Optimize the critics @@ -202,7 +208,7 @@ class TD3(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(TD3, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -215,7 +221,7 @@ class TD3(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 33271c4..035e3b6 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a0 +1.6.1a0 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 45c5e6a..0e028e6 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -4,9 +4,10 @@ import pytest import torch as th from gym import spaces -from 
stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer +from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -71,7 +72,7 @@ def test_replay_buffer_normalization(replay_buffer_cls): env = make_vec_env(env) env = VecNormalize(env) - buffer = replay_buffer_cls(100, env.observation_space, env.action_space) + buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu") # Interract and store transitions env.reset() @@ -94,3 +95,47 @@ def test_replay_buffer_normalization(replay_buffer_cls): assert th.allclose(observations.mean(0), th.zeros(1), atol=1) # Test reward normalization assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1) + + +@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer]) +@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"]) +def test_device_buffer(replay_buffer_cls, device): + if device == "cuda" and not th.cuda.is_available(): + pytest.skip("CUDA not available") + + env = { + RolloutBuffer: DummyEnv, + DictRolloutBuffer: DummyDictEnv, + ReplayBuffer: DummyEnv, + DictReplayBuffer: DummyDictEnv, + }[replay_buffer_cls] + env = make_vec_env(env) + + buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) + + # Interract and store transitions + obs = env.reset() + for _ in range(100): + action = env.action_space.sample() + next_obs, reward, done, info = env.step(action) + if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1) + buffer.add(obs, action, reward, episode_start, values, log_prob) + else: + 
buffer.add(obs, next_obs, action, reward, done, info) + obs = next_obs + + # Get data from the buffer + if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + data = buffer.get(50) + elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: + data = buffer.sample(50) + + # Check that all data are on the desired device + desired_device = get_device(device).type + for value in list(data): + if isinstance(value, dict): + for key in value.keys(): + assert value[key].device.type == desired_device + elif isinstance(value, th.Tensor): + assert value.device.type == desired_device diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 3652b18..07920db 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -163,7 +163,9 @@ def test_categorical(dist, CAT_ACTIONS): BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), - MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))), + MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution( + th.rand(1, sum([N_ACTIONS, N_ACTIONS])) + ), SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), StateDependentNoiseDistribution(N_ACTIONS).proba_distribution( th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS]) diff --git a/tests/test_envs.py b/tests/test_envs.py index 671e2a5..1cec17a 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -141,6 +141,8 @@ def test_non_default_spaces(new_obs_space): spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32), # Same boundaries spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32), + # Unbounded action space + spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, 
except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), ], @@ -156,8 +158,14 @@ def test_non_default_action_spaces(new_action_space): # Change the action space env.action_space = new_action_space - with pytest.warns(UserWarning): - check_env(env) + # Unbounded action space throws an error, + # the rest only warning + if not np.all(np.isfinite(env.action_space.low)): + with pytest.raises(AssertionError), pytest.warns(UserWarning): + check_env(env) + else: + with pytest.warns(UserWarning): + check_env(env) def check_reset_assert_error(env, new_reset_return): diff --git a/tests/test_gae.py b/tests/test_gae.py index 54e03b8..8e461ed 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -10,7 +10,7 @@ from stable_baselines3.common.policies import ActorCriticPolicy class CustomEnv(gym.Env): def __init__(self, max_steps=8): - super(CustomEnv, self).__init__() + super().__init__() self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.max_steps = max_steps @@ -54,7 +54,7 @@ class InfiniteHorizonEnv(gym.Env): class CheckGAECallback(BaseCallback): def __init__(self): - super(CheckGAECallback, self).__init__(verbose=0) + super().__init__(verbose=0) def _on_rollout_end(self): buffer = self.model.rollout_buffer @@ -99,7 +99,7 @@ class CustomPolicy(ActorCriticPolicy): """Custom Policy with a constant value function""" def __init__(self, *args, **kwargs): - super(CustomPolicy, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.constant_value = 0.0 def forward(self, obs, deterministic=False): diff --git a/tests/test_her.py b/tests/test_her.py index 0f6d75f..888d36a 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -156,7 +156,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): params = deepcopy(model.policy.state_dict()) # Modify all parameters to be random 
values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values model.policy.load_state_dict(random_params) diff --git a/tests/test_logger.py b/tests/test_logger.py index 6fe536f..a55f88a 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,6 +1,7 @@ import os import time from typing import Sequence +from unittest import mock import gym import numpy as np @@ -381,3 +382,16 @@ def test_fps_logger(tmp_path, algo): # third time, FPS should be the same model.learn(100, log_interval=1, reset_num_timesteps=False) assert max_fps / 10 <= logger.name_to_value["time/fps"] <= max_fps + + +@pytest.mark.parametrize("algo", [A2C, DQN]) +def test_fps_no_div_zero(algo): + """Set time to constant and train algorithm to check no division by zero error. + + Time can appear to be constant during short runs on platforms with low-precision + timers. We should avoid division by zero errors e.g. 
when computing FPS in + this situation.""" + with mock.patch("time.time", lambda: 42.0): + with mock.patch("time.time_ns", lambda: 42.0): + model = algo("MlpPolicy", "CartPole-v1") + model.learn(total_timesteps=100) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index d3d041b..4c1d3cf 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -14,7 +14,7 @@ def test_monitor(tmp_path): """ env = gym.make("CartPole-v1") env.seed(0) - monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env = Monitor(env, monitor_file) monitor_env.reset() total_steps = 1000 @@ -37,7 +37,7 @@ def test_monitor(tmp_path): assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards) _ = monitor_env.get_episode_times() - with open(monitor_file, "rt") as file_handler: + with open(monitor_file) as file_handler: first_line = file_handler.readline() assert first_line.startswith("#") metadata = json.loads(first_line[1:]) @@ -56,7 +56,7 @@ def test_monitor_load_results(tmp_path): tmp_path = str(tmp_path) env1 = gym.make("CartPole-v1") env1.seed(0) - monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env1 = Monitor(env1, monitor_file1) monitor_files = get_monitor_files(tmp_path) @@ -76,7 +76,7 @@ def test_monitor_load_results(tmp_path): env2 = gym.make("CartPole-v1") env2.seed(0) - monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env2 = Monitor(env2, monitor_file2) monitor_files = get_monitor_files(tmp_path) assert len(monitor_files) == 2 diff --git a/tests/test_predict.py b/tests/test_predict.py index 853f4d1..89cdb09 
100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -73,11 +73,13 @@ def test_predict(model_class, env_id, device): obs = env.reset() action, _ = model.predict(obs) + assert isinstance(action, np.ndarray) assert action.shape == env.action_space.shape assert env.action_space.contains(action) vec_env_obs = vec_env.reset() action, _ = model.predict(vec_env_obs) + assert isinstance(action, np.ndarray) assert action.shape[0] == vec_env_obs.shape[0] # Special case for DQN to check the epsilon greedy exploration diff --git a/tests/test_run.py b/tests/test_run.py index e4e8a2e..b0a9a11 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -10,7 +10,10 @@ normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) @pytest.mark.parametrize("model_class", [TD3, DDPG]) -@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))]) +@pytest.mark.parametrize( + "action_noise", + [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))], +) def test_deterministic_pg(model_class, action_noise): """ Test for DDPG and variants (TD3). diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 452e6fb..d7a74c5 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -64,7 +64,7 @@ def test_save_load(tmp_path, model_class): model.set_parameters(invalid_object_params, exact_match=False) # Test that exact_match catches when something was missed. 
- missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1]) + missing_object_params = {k: v for k, v in list(original_params.items())[:-1]} with pytest.raises(ValueError): model.set_parameters(missing_object_params, exact_match=True) @@ -375,6 +375,9 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage): select_env(model_class), buffer_size=100, optimize_memory_usage=optimize_memory_usage, + # we cannot use optimize_memory_usage and handle_timeout_termination + # at the same time + replay_buffer_kwargs={"handle_timeout_termination": not optimize_memory_usage}, policy_kwargs=dict(net_arch=[64]), learning_starts=10, ) @@ -446,7 +449,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde): params = deepcopy(policy.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values policy.load_state_dict(random_params) @@ -537,7 +540,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): params = deepcopy(q_net.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values q_net.load_state_dict(random_params) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 54994b2..0696492 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -9,7 +9,7 @@ from stable_baselines3.common.evaluation import evaluate_policy class DummyMultiDiscreteSpace(gym.Env): def __init__(self, nvec): - super(DummyMultiDiscreteSpace, self).__init__() + super().__init__() self.observation_space = gym.spaces.MultiDiscrete(nvec) self.action_space = 
gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) @@ -22,7 +22,7 @@ class DummyMultiDiscreteSpace(gym.Env): class DummyMultiBinary(gym.Env): def __init__(self, n): - super(DummyMultiBinary, self).__init__() + super().__init__() self.observation_space = gym.spaces.MultiBinary(n) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) @@ -33,6 +33,19 @@ class DummyMultiBinary(gym.Env): return self.observation_space.sample(), 0.0, False, {} +class DummyMultidimensionalAction(gym.Env): + def __init__(self): + super().__init__() + self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) + + def reset(self): + return self.observation_space.sample() + + def step(self, action): + return self.observation_space.sample(), 0.0, False, {} + + @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) @pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)]) def test_identity_spaces(model_class, env): @@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env): @pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3]) -@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) +@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()]) def test_action_spaces(model_class, env): + kwargs = {} if model_class in [SAC, DDPG, TD3]: - supported_action_space = env == "Pendulum-v1" + supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction) + kwargs["learning_starts"] = 2 + kwargs["train_freq"] = 32 elif model_class == DQN: supported_action_space = env == "CartPole-v1" elif model_class in [A2C, PPO]: supported_action_space = True + kwargs["n_steps"] = 64 if supported_action_space: - model_class("MlpPolicy", env) + model = model_class("MlpPolicy", env, **kwargs) + if isinstance(env, DummyMultidimensionalAction): + 
model.learn(64) else: with pytest.raises(AssertionError): model_class("MlpPolicy", env) +def test_sde_multi_dim(): + SAC( + "MlpPolicy", + DummyMultidimensionalAction(), + learning_starts=10, + use_sde=True, + sde_sample_freq=2, + use_sde_at_warmup=True, + ).learn(20) + + @pytest.mark.parametrize("model_class", [A2C, PPO, DQN]) @pytest.mark.parametrize("env", ["Taxi-v3"]) def test_discrete_obs_space(model_class, env): diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py index 20f58b9..6dccf41 100644 --- a/tests/test_tensorboard.py +++ b/tests/test_tensorboard.py @@ -3,6 +3,7 @@ import os import pytest from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.utils import get_latest_run_id MODEL_DICT = { "a2c": (A2C, "CartPole-v1"), @@ -35,3 +36,13 @@ def test_tensorboard(tmp_path, model_name): assert os.path.isdir(tmp_path / str(logname + "_1")) # Check that the log dir name increments correctly assert os.path.isdir(tmp_path / str(logname + "_2")) + + +def test_escape_log_name(tmp_path): + # Log name that must be escaped + log_name = "filename[16, 16]" + # Create folder + os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True) + os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True) + last_run_id = get_latest_run_id(tmp_path, log_name) + assert last_run_id == 2 diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index 1ea2efe..4f023e9 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -28,7 +28,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenBatchNormDropoutExtractor, self).__init__( + super().__init__( observation_space, get_flattened_obs_dim(observation_space), ) diff --git a/tests/test_utils.py b/tests/test_utils.py index b07bbe9..67f2ad1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -180,7 +180,7 @@ class AlwaysDoneWrapper(gym.Wrapper): # Pretends that 
environment only has single step for each # episode. def __init__(self, env): - super(AlwaysDoneWrapper, self).__init__(env) + super().__init__(env) self.last_obs = None self.needs_reset = True diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 265da2e..9623557 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -12,7 +12,7 @@ class NanAndInfEnv(gym.Env): metadata = {"render.modes": ["human"]} def __init__(self): - super(NanAndInfEnv, self).__init__() + super().__init__() self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 9a4c118..93ea348 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -31,7 +31,7 @@ class CustomGymEnv(gym.Env): return self.state def step(self, action): - reward = 1 + reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 done = self.current_step >= self.ep_length @@ -45,7 +45,9 @@ class CustomGymEnv(gym.Env): return np.zeros((4, 4, 3)) def seed(self, seed=None): - pass + if seed is not None: + np.random.seed(seed) + self.observation_space.seed(seed) @staticmethod def custom_method(dim_0=1, dim_1=1): @@ -440,3 +442,34 @@ def test_vec_env_is_wrapped(): vec_env = VecFrameStack(vec_env, n_stack=2) assert vec_env.env_is_wrapped(Monitor) == [False, True] + + +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_vec_seeding(vec_env_class): + def make_env(): + return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + + # For SubprocVecEnv check for all starting methods + start_methods = [None] + if vec_env_class != DummyVecEnv: + all_methods = {"forkserver", "spawn", "fork"} + available_methods = multiprocessing.get_all_start_methods() + start_methods = list(all_methods.intersection(available_methods)) + + for start_method in 
start_methods: + if start_method is not None: + vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method) + + n_envs = 3 + vec_env = vec_env_class([make_env] * n_envs) + # Seed with no argument + vec_env.seed() + obs = vec_env.reset() + _, rewards, _, _ = vec_env.step(np.array([vec_env.action_space.sample() for _ in range(n_envs)])) + # Seed should be different per process + assert not np.allclose(obs[0], obs[1]) + assert not np.allclose(rewards[0], rewards[1]) + assert not np.allclose(obs[1], obs[2]) + assert not np.allclose(rewards[1], rewards[2]) + + vec_env.close() diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 974202b..5ccc33e 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -36,7 +36,7 @@ def test_vec_monitor(tmp_path): monitor_env.close() - with open(monitor_file, "rt") as file_handler: + with open(monitor_file) as file_handler: first_line = file_handler.readline() assert first_line.startswith("#") metadata = json.loads(first_line[1:]) @@ -66,7 +66,7 @@ def test_vec_monitor_info_keywords(tmp_path): monitor_env.close() - with open(monitor_file, "rt") as f: + with open(monitor_file) as f: reader = csv.reader(f) for i, line in enumerate(reader): if i == 0 or i == 1: diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 07ad77f..a363e40 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -47,7 +47,7 @@ class DummyDictEnv(gym.GoalEnv): """ def __init__(self): - super(DummyDictEnv, self).__init__() + super().__init__() self.observation_space = spaces.Dict( { "observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32), @@ -388,11 +388,11 @@ def test_offpolicy_normalization(model_class, online_sampling): @pytest.mark.parametrize("make_env", [make_env, make_dict_env]) def test_sync_vec_normalize(make_env): - env = DummyVecEnv([make_env]) + original_env = DummyVecEnv([make_env]) - assert unwrap_vec_normalize(env) is None + assert 
unwrap_vec_normalize(original_env) is None - env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) + env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) assert isinstance(unwrap_vec_normalize(env), VecNormalize) @@ -433,6 +433,17 @@ def test_sync_vec_normalize(make_env): assert allclose(obs, eval_env.normalize_obs(original_obs)) assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards)) + # Check synchronization when only reward is normalized + env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0) + eval_env = DummyVecEnv([make_env]) + eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False) + env.reset() + env.step([env.action_space.sample()]) + assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) + sync_envs_normalization(env, eval_env) + assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) + assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var) + def test_discrete_obs(): with pytest.raises(ValueError, match=".*only supports.*"):