stable-baselines3/stable_baselines3/common/on_policy_algorithm.py

import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")
class OnPolicyAlgorithm(BaseAlgorithm):
    """
    The base for On-Policy algorithms (ex: A2C/PPO).

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param supported_action_spaces: The action spaces supported by the algorithm.
    """
    rollout_buffer: RolloutBuffer
    policy: ActorCriticPolicy

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        n_steps: int,
        gamma: float,
        gae_lambda: float,
        ent_coef: float,
        vf_coef: float,
        max_grad_norm: float,
        use_sde: bool,
        sde_sample_freq: int,
        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        monitor_wrapper: bool = True,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
    ):
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            support_multi_env=True,
            monitor_wrapper=monitor_wrapper,
            seed=seed,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            supported_action_spaces=supported_action_spaces,
        )

        self.n_steps = n_steps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.rollout_buffer_class = rollout_buffer_class
        self.rollout_buffer_kwargs = rollout_buffer_kwargs or {}

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        if self.rollout_buffer_class is None:
            if isinstance(self.observation_space, spaces.Dict):
                self.rollout_buffer_class = DictRolloutBuffer
            else:
                self.rollout_buffer_class = RolloutBuffer

        self.rollout_buffer = self.rollout_buffer_class(
            self.n_steps,
            self.observation_space,  # type: ignore[arg-type]
            self.action_space,
            device=self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
            **self.rollout_buffer_kwargs,
        )
        self.policy = self.policy_class(  # type: ignore[assignment]
            self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
        )
        self.policy = self.policy.to(self.device)

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be confused with the concept of rollout used in model-based RL or planning.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to PyTorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions

            if isinstance(self.action_space, spaces.Box):
                if self.policy.squash_output:
                    # Unscale the actions to match env bounds
                    # if they were previously squashed (scaled in [-1, 1])
                    clipped_actions = self.policy.unscale_action(clipped_actions)
                else:
                    # Otherwise, clip the actions to avoid out-of-bounds errors
                    # as we are sampling from an unbounded Gaussian distribution
                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if not callback.on_step():
                return False

            self._update_info_buffer(infos, dones)
            n_steps += 1

            if isinstance(self.action_space, spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstrapping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
2023-04-14 11:13:59 +00:00
                        terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(
                self._last_obs,  # type: ignore[arg-type]
                actions,
                rewards,
                self._last_episode_starts,  # type: ignore[arg-type]
                values,
                log_probs,
            )
            self._last_obs = new_obs  # type: ignore[assignment]
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))  # type: ignore[arg-type]

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.update_locals(locals())

        callback.on_rollout_end()

        return True

    def train(self) -> None:
        """
        Consume current rollout data and update policy parameters.
        Implemented by individual algorithms.
        """
        raise NotImplementedError

    def _dump_logs(self, iteration: int) -> None:
        """
        Write log.

        :param iteration: Current logging iteration
        """
        assert self.ep_info_buffer is not None
        assert self.ep_success_buffer is not None

        time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
        fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
        self.logger.record("time/iterations", iteration, exclude="tensorboard")
        if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
            self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
            self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
        self.logger.record("time/fps", fps)
        self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        if len(self.ep_success_buffer) > 0:
            self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
        self.logger.dump(step=self.num_timesteps)

    def learn(
        self: SelfOnPolicyAlgorithm,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "OnPolicyAlgorithm",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfOnPolicyAlgorithm:
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            callback,
            reset_num_timesteps,
            tb_log_name,
            progress_bar,
        )

        callback.on_training_start(locals(), globals())

        assert self.env is not None

        while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

            if not continue_training:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                assert self.ep_info_buffer is not None
                self._dump_logs(iteration)

            self.train()

        callback.on_training_end()

        return self

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []