# stable-baselines3/stable_baselines3/common/base_class.py

"""Abstract base classes for RL algorithms."""

import io
import pathlib
import time
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union

import gym
import numpy as np
import torch as th

from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import (
    check_for_correct_spaces,
    get_device,
    get_schedule_fn,
    get_system_info,
    set_random_seed,
    update_learning_rate,
)
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    VecEnv,
    VecNormalize,
    VecTransposeImage,
    is_vecenv_wrapped,
    unwrap_vec_normalize,
)

SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")
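# Illustrative sketch only (the classes below are hypothetical and are not part
# of SB3). A TypeVar bound to a base class lets a method annotated as
# ``self: SelfT -> SelfT`` report the *subclass* type to static type checkers,
# so chained calls keep the concrete algorithm type. SelfBaseAlgorithm above is
# presumably meant for the same pattern on BaseAlgorithm methods that return self.
from typing import TypeVar

_SelfExampleBase = TypeVar("_SelfExampleBase", bound="_ExampleBase")


class _ExampleBase:
    """Hypothetical stand-in for a base algorithm class."""

    def __init__(self) -> None:
        self.num_timesteps = 0

    def learn(self: _SelfExampleBase, total_timesteps: int) -> _SelfExampleBase:
        # Returning self allows chaining; the bound TypeVar keeps the subclass type
        self.num_timesteps += total_timesteps
        return self


class _ExampleAlgo(_ExampleBase):
    """Hypothetical subclass with a method the base class does not have."""

    def extra(self) -> str:
        return "only defined on the subclass"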
def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
    """If env is a string, make the environment; otherwise, return env.

    :param env: The environment to learn from.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating if the environment is created
    :return: A Gym (vector) environment.
    """
    if isinstance(env, str):
        if verbose >= 1:
            print(f"Creating environment from the given name '{env}'")
        env = gym.make(env)
    return env
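# Illustrative sketch only (hypothetical helper, not part of this module). The
# ``learning_rate`` argument of BaseAlgorithm may be a float or a schedule: a
# callable that maps the remaining training progress (1 at the start of
# training, 0 at the end) to a learning rate. A linear decay to zero could be
# written like this, assuming that 1 -> 0 progress convention.
from typing import Callable


def _example_linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a hypothetical schedule decaying linearly from ``initial_value`` to 0."""

    def schedule(progress_remaining: float) -> float:
        # progress_remaining == 1.0 gives initial_value; 0.0 gives 0.0
        return progress_remaining * initial_value

    return schedule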
class BaseAlgorithm(ABC):
    """
    The base class for RL algorithms.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from
        (if registered in Gym, can be a str; can be None when loading trained models)
    :param learning_rate: learning rate for the optimizer;
        it can be a function of the current progress remaining (from 1 to 0)
    :param policy_kwargs: Additional arguments to be passed to the policy on creation
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param device: Device on which the code should run.
        By default, it will try to use a CUDA-compatible device and fall back to CPU
        if that is not possible.
    :param support_multi_env: Whether the algorithm supports training
        with multiple environments (as in A2C)
    :param monitor_wrapper: When creating an environment, whether to wrap it
        in a Monitor wrapper.
    :param seed: Seed for the pseudo random generators
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE.
        Default: -1 (only sample at the beginning of the rollout)
    :param supported_action_spaces: The action spaces supported by the algorithm.
    """

    # Policy aliases (see _get_policy_from_name())
    policy_aliases: Dict[str, Type[BasePolicy]] = {}

    def __init__(
        self,
        policy: Union[str, Type[BasePolicy]],
        env: Union[GymEnv, str, None],
        learning_rate: Union[float, Schedule],
        policy_kwargs: Optional[Dict[str, Any]] = None,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
        device: Union[th.device, str] = "auto",
        support_multi_env: bool = False,
        monitor_wrapper: bool = True,
        seed: Optional[int] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
    ):
        if isinstance(policy, str):
            self.policy_class = self._get_policy_from_name(policy)
        else:
            self.policy_class = policy

        self.device = get_device(device)
        if verbose >= 1:
            print(f"Using {self.device} device")

        self.env = None  # type: Optional[GymEnv]
        # get VecNormalize object if needed
        self._vec_normalize_env = unwrap_vec_normalize(env)

        self.verbose = verbose
        self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
        self.observation_space = None  # type: Optional[gym.spaces.Space]
        self.action_space = None  # type: Optional[gym.spaces.Space]

        self.n_envs = None
        self.num_timesteps = 0
        # Used for updating schedules
        self._total_timesteps = 0
        # Used for computing FPS; updated at each call of learn()
        self._num_timesteps_at_start = 0

        self.seed = seed
        self.action_noise: Optional[ActionNoise] = None
        self.start_time = None
        self.policy = None
        self.learning_rate = learning_rate
        self.tensorboard_log = tensorboard_log
        self.lr_schedule = None  # type: Optional[Schedule]
        self._last_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
        self._last_episode_starts = None  # type: Optional[np.ndarray]

        # When using VecNormalize:
2021-05-11 10:29:30 +00:00
        self._last_original_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
        self._episode_num = 0
        # Used for gSDE only
        self.use_sde = use_sde
        self.sde_sample_freq = sde_sample_freq
        # Track the training progress remaining (from 1 to 0)
        # this is used to update the learning rate
        self._current_progress_remaining = 1
        # Buffers for logging
        self.ep_info_buffer = None  # type: Optional[deque]
        self.ep_success_buffer = None  # type: Optional[deque]
        # For logging (and TD3 delayed updates)
        self._n_updates = 0  # type: int
        # The logger object
        self._logger = None  # type: Logger
        # Whether the user passed a custom logger or not
        self._custom_logger = False

        # Create and wrap the env if needed
        if env is not None:
            env = maybe_make_env(env, self.verbose)
            env = self._wrap_env(env, self.verbose, monitor_wrapper)

            self.observation_space = env.observation_space
            self.action_space = env.action_space
            self.n_envs = env.num_envs
            self.env = env

            if supported_action_spaces is not None:
                assert isinstance(self.action_space, supported_action_spaces), (
                    f"The algorithm only supports {supported_action_spaces} as action spaces "
                    f"but {self.action_space} was provided"
                )

            if not support_multi_env and self.n_envs > 1:
                raise ValueError(
                    "Error: the model does not support multiple envs; it requires a single vectorized environment."
                )

            # Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
            if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
                raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")

            if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
                raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")

            if isinstance(self.action_space, gym.spaces.Box):
                assert np.all(
                    np.isfinite(np.array([self.action_space.low, self.action_space.high]))
                ), "Continuous action space must have a finite lower and upper bound"

    @staticmethod
    def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
        """
        Wrap the environment with the appropriate wrappers if needed,
        for instance to vectorize it or to re-order the image channels.

        :param env: The environment to wrap (either a gym env or a ``VecEnv``)
        :param verbose: Verbosity level: 0 for no output, 1 for indicating wrappers used
        :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
        :return: The wrapped environment.
        """
        if not isinstance(env, VecEnv):
            if not is_wrapped(env, Monitor) and monitor_wrapper:
                if verbose >= 1:
                    print("Wrapping the env with a `Monitor` wrapper")
                env = Monitor(env)
            if verbose >= 1:
                print("Wrapping the env in a DummyVecEnv.")
            env = DummyVecEnv([lambda: env])

        # Make sure that dict-spaces are not nested (not supported)
        check_for_nested_spaces(env.observation_space)

        if not is_vecenv_wrapped(env, VecTransposeImage):
            wrap_with_vectranspose = False
            if isinstance(env.observation_space, gym.spaces.Dict):
                # If even one of the keys is an image space in need of transpose, apply transpose.
                # If the image spaces are not consistent (for instance, one is channel-first,
                # the other channel-last), VecTransposeImage will throw an error
                for space in env.observation_space.spaces.values():
                    wrap_with_vectranspose = wrap_with_vectranspose or (
                        is_image_space(space) and not is_image_space_channels_first(space)
                    )
            else:
                wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
                    env.observation_space
                )

            if wrap_with_vectranspose:
                if verbose >= 1:
                    print("Wrapping the env in a VecTransposeImage.")
                env = VecTransposeImage(env)
return env
@abstractmethod
def _setup_model(self) -> None:
"""Create networks, buffer and optimizers."""
def set_logger(self, logger: Logger) -> None:
"""
Setter for the logger object.
.. warning::
When passing a custom logger object,
this will overwrite ``tensorboard_log`` and ``verbose`` settings
passed to the constructor.
"""
self._logger = logger
# User defined logger
self._custom_logger = True
@property
def logger(self) -> Logger:
"""Getter for the logger object."""
return self._logger
def _setup_lr_schedule(self) -> None:
"""Transform to callable if needed."""
self.lr_schedule = get_schedule_fn(self.learning_rate)
def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
"""
Compute the current progress remaining (starts at 1 and ends at 0).
:param num_timesteps: current number of timesteps
:param total_timesteps: total number of timesteps of the current training run
"""
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
"""
Update the optimizers' learning rate using the current learning rate schedule
and the current progress remaining (from 1 to 0).
:param optimizers:
An optimizer or a list of optimizers.
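A minimal, stdlib-only sketch of how the schedule and the optimizer update fit
together, assuming a linear schedule; ``linear_schedule``, ``FakeOptimizer`` and
the local ``update_learning_rate`` are hypothetical stand-ins (a torch optimizer
exposes the same ``param_groups`` list of dicts). With 2500 of 10000 steps done,
progress remaining is 0.75, so the learning rate becomes 0.75 * 1e-3:

```python
from typing import Callable, Dict, List


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    # Hypothetical helper: scales linearly with progress remaining (1.0 -> 0.0)
    def schedule(progress_remaining: float) -> float:
        return progress_remaining * initial_value

    return schedule


class FakeOptimizer:
    # Stand-in for a torch optimizer: only exposes ``param_groups``
    def __init__(self, lr: float) -> None:
        self.param_groups: List[Dict[str, float]] = [{"lr": lr}, {"lr": lr}]


def update_learning_rate(optimizer: FakeOptimizer, learning_rate: float) -> None:
    # Overwrite the learning rate of every parameter group
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate


lr_schedule = linear_schedule(1e-3)
optimizer = FakeOptimizer(lr=1e-3)
num_timesteps, total_timesteps = 2500, 10000
progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
update_learning_rate(optimizer, lr_schedule(progress_remaining))
```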
"""
# Log the current learning rate
self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))
if not isinstance(optimizers, list):
optimizers = [optimizers]
for optimizer in optimizers:
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
def _excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded from being
saved by pickling. E.g. replay buffers are skipped by default
as they take up a lot of space. PyTorch variables should be excluded
with this so they can be stored with ``th.save``.
:return: List of parameters that should be excluded from being saved with pickle.
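A minimal sketch of how the exclusion list might be applied, assuming the saved
data comes from the instance ``__dict__``; ``Algo`` and ``pickle_without`` are
hypothetical and only the standard library is used. Excluded attributes never
reach ``pickle``, so they have to be stored separately (e.g. with ``th.save``):

```python
import pickle
from typing import Any, Dict, List


class Algo:
    # Hypothetical stand-in for an algorithm object
    def __init__(self) -> None:
        self.learning_rate = 3e-4
        self.gamma = 0.99
        self.replay_buffer = list(range(100_000))  # large, skipped by default
        self.policy = object()  # stands in for a torch module saved via th.save

    def _excluded_save_params(self) -> List[str]:
        return ["policy", "replay_buffer"]


def pickle_without(obj: Any, exclude: List[str]) -> bytes:
    # Copy the instance dict and drop excluded attributes before pickling
    data: Dict[str, Any] = {k: v for k, v in vars(obj).items() if k not in exclude}
    return pickle.dumps(data)


algo = Algo()
restored = pickle.loads(pickle_without(algo, algo._excluded_save_params()))
```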
"""
return [
"policy",
"device",
"env",
"replay_buffer",
"rollout_buffer",
"_vec_normalize_env",
"_episode_storage",
"_logger",
"_custom_logger",
]
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
"""
Get a policy class from its name representation.
The goal here is to standardize policy naming, e.g.
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
and they receive respective policies that work for them.
:param policy_name: Alias of the policy
:return: A policy class (type)
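A minimal sketch of the alias lookup, assuming ``policy_aliases`` maps alias
strings to policy classes; the classes below are hypothetical placeholders:

```python
from typing import Dict, Type


class BasePolicy:
    pass


class MlpPolicy(BasePolicy):
    pass


class CnnPolicy(BasePolicy):
    pass


policy_aliases: Dict[str, Type[BasePolicy]] = {"MlpPolicy": MlpPolicy, "CnnPolicy": CnnPolicy}


def get_policy_from_name(policy_name: str) -> Type[BasePolicy]:
    # Same contract as the method: known alias -> class, otherwise ValueError
    if policy_name in policy_aliases:
        return policy_aliases[policy_name]
    raise ValueError(f"Policy {policy_name} unknown")
```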
"""
if policy_name in self.policy_aliases:
return self.policy_aliases[policy_name]
else:
raise ValueError(f"Policy {policy_name} unknown")
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
"""
Get the names of the torch variables that will be saved with
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
pickling strategy. This is to handle device placement correctly.
Names can point to specific variables under classes, e.g.
"policy.optimizer" would point to the ``optimizer`` object of ``self.policy``
if this object exists.
:return:
List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
and list of other Torch variables to store with ``th.save``.
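A minimal sketch of how a dotted name such as "policy.optimizer" can be resolved
one attribute at a time, similar in spirit to the ``recursive_getattr`` helper
used by ``set_parameters``; ``Policy`` and ``Algo`` are hypothetical:

```python
import functools
from typing import Any


def recursive_getattr(obj: Any, attr: str) -> Any:
    # Resolve "a.b.c" as getattr(getattr(getattr(obj, "a"), "b"), "c")
    return functools.reduce(getattr, attr.split("."), obj)


class Policy:
    def __init__(self) -> None:
        self.optimizer = "adam-optimizer"


class Algo:
    def __init__(self) -> None:
        self.policy = Policy()


algo = Algo()
```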
"""
state_dicts = ["policy"]
return state_dicts, []
def _init_callback(
self,
callback: MaybeCallback,
progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
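A minimal sketch of the normalization performed here, using hypothetical
stand-ins for ``BaseCallback``, ``CallbackList``, ``ConvertCallback`` and
``ProgressBarCallback`` (not the SB3 classes): a list is wrapped first, anything
that is still not a callback object (e.g. a function or None) is converted, and
the progress bar is appended last:

```python
from typing import Any, Callable, List, Optional


class BaseCallback:
    pass


class CallbackList(BaseCallback):
    def __init__(self, callbacks: List[BaseCallback]) -> None:
        self.callbacks = callbacks


class ConvertCallback(BaseCallback):
    # Wraps a plain function (or None) so it exposes the callback interface
    def __init__(self, fn: Optional[Callable[..., bool]]) -> None:
        self.fn = fn


class ProgressBarCallback(BaseCallback):
    pass


def init_callback(callback: Any, progress_bar: bool = False) -> BaseCallback:
    if isinstance(callback, list):
        callback = CallbackList(callback)
    if not isinstance(callback, BaseCallback):
        callback = ConvertCallback(callback)
    if progress_bar:
        callback = CallbackList([callback, ProgressBarCallback()])
    return callback
```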
"""
# Convert a list of callbacks into a callback
if isinstance(callback, list):
callback = CallbackList(callback)
# Convert functional callback to object
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])
callback.init_callback(self)
return callback
def _setup_learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param callback: Callback(s) called at every step with state of the algorithm.
:param reset_num_timesteps: Whether or not to reset the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:param progress_bar: Display a progress bar using tqdm and rich.
:return: Total timesteps and callback(s)
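A minimal sketch of the timestep bookkeeping, assuming ``num_timesteps`` persists
between ``learn()`` calls; ``TimestepCounter`` is a hypothetical stand-in. With
``reset_num_timesteps=False`` the requested budget is added to the steps already
taken, so a second call trains for 1000 more steps instead of stopping at once:

```python
from collections import deque
from typing import Optional


class TimestepCounter:
    # Hypothetical stand-in holding only the attributes touched here
    def __init__(self) -> None:
        self.num_timesteps = 0
        self.ep_info_buffer: Optional[deque] = None

    def setup_learn(self, total_timesteps: int, reset_num_timesteps: bool = True) -> int:
        if self.ep_info_buffer is None or reset_num_timesteps:
            # Keep statistics for the last 100 episodes only
            self.ep_info_buffer = deque(maxlen=100)
        if reset_num_timesteps:
            self.num_timesteps = 0
        else:
            total_timesteps += self.num_timesteps
        return total_timesteps


counter = TimestepCounter()
first_target = counter.setup_learn(1000)
counter.num_timesteps = 1000  # pretend training ran to completion
second_target = counter.setup_learn(1000, reset_num_timesteps=False)
```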
"""
self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=100)
self.ep_success_buffer = deque(maxlen=100)
if self.action_noise is not None:
self.action_noise.reset()
if reset_num_timesteps:
self.num_timesteps = 0
self._episode_num = 0
else:
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps
self._total_timesteps = total_timesteps
self._num_timesteps_at_start = self.num_timesteps
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
# Configure logger's outputs if no logger was passed
if not self._custom_logger:
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
callback = self._init_callback(callback, progress_bar)
return total_timesteps, callback
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
"""
Retrieve reward, episode length, episode success and update the buffer
if using Monitor wrapper or a GoalEnv.
:param infos: List of additional information about the transition.
:param dones: Termination signals
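A minimal sketch of the buffer update, assuming ``infos`` come from a ``Monitor``
wrapper ("episode" key) or a goal env ("is_success" key); plain lists replace
the NumPy ``dones`` array, and the module-level deques are hypothetical:

```python
from collections import deque
from typing import Any, Dict, List, Optional

ep_info_buffer: deque = deque(maxlen=100)
ep_success_buffer: deque = deque(maxlen=100)


def update_info_buffer(infos: List[Dict[str, Any]], dones: Optional[List[bool]] = None) -> None:
    if dones is None:
        dones = [False] * len(infos)
    for idx, info in enumerate(infos):
        maybe_ep_info = info.get("episode")
        maybe_is_success = info.get("is_success")
        if maybe_ep_info is not None:
            ep_info_buffer.append(maybe_ep_info)
        # Success is only recorded when the episode has ended
        if maybe_is_success is not None and dones[idx]:
            ep_success_buffer.append(maybe_is_success)


update_info_buffer(
    [{"episode": {"r": 1.0, "l": 10}, "is_success": True}, {"is_success": False}],
    dones=[True, False],
)
```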
"""
if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
maybe_ep_info = info.get("episode")
maybe_is_success = info.get("is_success")
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
if maybe_is_success is not None and dones[idx]:
self.ep_success_buffer.append(maybe_is_success)
def get_env(self) -> Optional[VecEnv]:
"""
Returns the current environment (can be None if not defined).
:return: The current environment
"""
return self.env
def get_vec_normalize_env(self) -> Optional[VecNormalize]:
"""
Return the ``VecNormalize`` wrapper of the training env
if it exists.
:return: The ``VecNormalize`` env.
"""
return self._vec_normalize_env
def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
"""
Check the validity of the environment and, if it is coherent, set it as the current environment.
Any non-vectorized env is first wrapped into a vectorized one.
Checked parameters:
- observation_space
- action_space
- number of environments (must be equal to ``n_envs``)
:param env: The environment for learning a policy
:param force_reset: Force call to ``reset()`` before training
to avoid unexpected behavior.
See issue https://github.com/DLR-RM/stable-baselines3/issues/597
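A minimal sketch of the ``n_envs`` check and of ``force_reset``, assuming plain
objects in place of a ``VecEnv``; ``FakeVecEnv`` and ``Model`` are hypothetical,
and the space checks and wrapping done by the real method are omitted:

```python
from typing import Any, Optional


class FakeVecEnv:
    def __init__(self, num_envs: int) -> None:
        self.num_envs = num_envs


class Model:
    def __init__(self, n_envs: int) -> None:
        self.n_envs = n_envs
        self.env: Optional[FakeVecEnv] = None
        self._last_obs: Any = [0.0] * n_envs  # pretend a previous rollout left an obs

    def set_env(self, env: FakeVecEnv, force_reset: bool = True) -> None:
        # set_env cannot change the number of envs; load() is needed for that
        assert env.num_envs == self.n_envs, f"{env.num_envs} != {self.n_envs}"
        if force_reset:
            # Dropping the last obs forces env.reset() at the next learn() call
            self._last_obs = None
        self.env = env
```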
"""
# if it is not a VecEnv, make it a VecEnv
# and do other transformations (dict obs, image transpose) if needed
env = self._wrap_env(env, self.verbose)
assert env.num_envs == self.n_envs, (
"The number of environments to be set is different from the number of environments in the model: "
f"({env.num_envs} != {self.n_envs}), whereas `set_env` requires them to be the same. To load a model with "
f"a different number of environments, you must use `{self.__class__.__name__}.load(path, env)` instead"
)
# Check that the observation spaces match
check_for_correct_spaces(env, self.observation_space, self.action_space)
# Update VecNormalize object
# otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637
self._vec_normalize_env = unwrap_vec_normalize(env)
# Discard `_last_obs`, this will force the env to reset before training
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
if force_reset:
self._last_obs = None
self.n_envs = env.num_envs
self.env = env
@abstractmethod
def learn(
self: SelfBaseAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "run",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfBaseAlgorithm:
"""
Return a trained model.
:param total_timesteps: The total number of samples (env steps) to train on
:param callback: callback(s) called at every step with state of the algorithm.
:param log_interval: The number of timesteps before logging.
:param tb_log_name: the name of the run for TensorBoard logging
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param progress_bar: Display a progress bar using tqdm and rich.
:return: the trained model
"""

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies);
            they mark the beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next hidden state
            (used in recurrent policies)
        """
        return self.policy.predict(observation, state, episode_start, deterministic)
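To illustrate what the `episode_start` mask is for, here is a toy NumPy sketch, not SB3's recurrent implementation: `step_recurrent_state` and its `tanh` recurrence are hypothetical. Rows of the hidden state whose environment starts a new episode are zeroed before the update, so no memory leaks across episode boundaries.

```python
import numpy as np


def step_recurrent_state(hidden: np.ndarray, episode_start: np.ndarray, update: np.ndarray) -> np.ndarray:
    """Hypothetical recurrent step showing what ``episode_start`` is used for.

    hidden: (n_envs, hidden_dim) states returned by the previous call
    episode_start: (n_envs,) booleans, True where a new episode begins
    update: (n_envs, hidden_dim) stand-in for the current observation's contribution
    """
    # Zero the carried-over state of every environment that starts a new episode
    keep = (~episode_start).astype(hidden.dtype)[:, None]
    # Toy recurrence in place of a real LSTM/GRU cell
    return np.tanh(hidden * keep + update)


hidden = np.ones((2, 3))
episode_start = np.array([True, False])
next_hidden = step_recurrent_state(hidden, episode_start, update=np.zeros((2, 3)))
# Row 0 was reset (tanh(0) = 0); row 1 carries its state forward (tanh(1))
print(next_hidden.round(3))
```

In a rollout loop, the mask passed on the next call would typically be the `done` flags returned by the previous environment step.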

    def set_random_seed(self, seed: Optional[int] = None) -> None:
        """
        Set the seed of the pseudo-random generators
        (python, numpy, pytorch, gym, action_space)

        :param seed: the seed to use; if None, no seeding is performed
        """
        if seed is None:
            return
        set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
        self.action_space.seed(seed)
        # self.env is always a VecEnv
        if self.env is not None:
            self.env.seed(seed)
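A minimal sketch of why seeding gives reproducible draws, restricted to Python's `random` module and NumPy's global generator. `seed_python_and_numpy` is a hypothetical helper; PyTorch, gym and action-space seeding are deliberately left out.

```python
import random
from typing import Optional, Tuple

import numpy as np


def seed_python_and_numpy(seed: Optional[int]) -> None:
    """Hypothetical helper: seeds only Python's ``random`` and NumPy's global generator."""
    if seed is None:
        # Like ``set_random_seed``, do nothing when no seed is given
        return
    random.seed(seed)
    np.random.seed(seed)


def draw() -> Tuple[float, float]:
    # One sample from each global generator
    return random.random(), float(np.random.rand())


seed_python_and_numpy(42)
first = draw()
seed_python_and_numpy(42)
second = draw()
print(first == second)  # True
```

The separate `self.action_space.seed(seed)` call matters because gym spaces sample from their own generator, which these global seeds do not touch.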

    def set_parameters(
        self,
        load_path_or_dict: Union[str, Dict[str, Dict]],
        exact_match: bool = True,
        device: Union[th.device, str] = "auto",
    ) -> None:
        """
        Load parameters from a given zip-file or a nested dictionary containing parameters for
        different modules (see ``get_parameters``).

        :param load_path_or_dict: Location of the saved data (path or file-like, see ``save``), or a nested
            dictionary containing nn.Module parameters used by the policy. The dictionary maps
            object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``.
        :param exact_match: If True, the given parameters should include parameters for each
            module and each of their parameters, otherwise raises an Exception. If set to False, this
            can be used to update only specific parameters.
        :param device: Device on which the code should run.
        """
        params = None
        if isinstance(load_path_or_dict, dict):
            params = load_path_or_dict
        else:
            _, params, _ = load_from_zip_file(load_path_or_dict, device=device)

        # Keep track of which objects were updated.
        # `_get_torch_save_params` returns [params, other_pytorch_variables].
        # We are only interested in the former here.
        objects_needing_update = set(self._get_torch_save_params()[0])
        updated_objects = set()

        for name in params:
            attr = None
            try:
                attr = recursive_getattr(self, name)
            except Exception as e:
                # What errors could recursive_getattr throw? KeyError, but
                # possibly something else too (e.g. if key is an int?).
                # Catch anything for now.
                raise ValueError(f"Key {name} is an invalid object name.") from e

            if isinstance(attr, th.optim.Optimizer):
                # Optimizers do not support "strict" keyword...
                # Seems like they will just replace the whole
                # optimizer state with the given one.
                # On top of this, optimizer state-dict
                # seems to change (e.g. first ``optim.step()``),
                # which makes comparing state dictionary keys
                # invalid (there is also a nesting of dictionaries
                # with lists with dictionaries with ...), adding to the
                # mess.
                #
                # TL;DR: We might not be able to reliably say
                # if given state-dict is missing keys.
                #
                # Solution: Just load the state-dict as is, and trust
                # the user has provided a sensible state dictionary.
                attr.load_state_dict(params[name])
            else:
                # Assume attr is th.nn.Module
                attr.load_state_dict(params[name], strict=exact_match)
            updated_objects.add(name)

        if exact_match and updated_objects != objects_needing_update:
            raise ValueError(
                "Names of parameters do not match the agent's parameters: "
                f"expected {objects_needing_update}, got {updated_objects}"
            )
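The `exact_match` bookkeeping can be shown without PyTorch. The sketch below is a hypothetical, dict-only analogue of `set_parameters`: plain dicts stand in for modules and their state dicts, and the per-object key check loosely imitates `load_state_dict(..., strict=True)`.

```python
from typing import Any, Dict, Set


def load_named_states(
    targets: Dict[str, Dict[str, Any]],
    params: Dict[str, Dict[str, Any]],
    exact_match: bool = True,
) -> Set[str]:
    """Hypothetical, dict-only mirror of the bookkeeping in ``set_parameters``.

    ``targets`` maps object names (e.g. "policy") to mutable stand-in state dicts and
    ``params`` holds new values for some or all of those objects.
    """
    updated: Set[str] = set()
    for name, state in params.items():
        if name not in targets:
            raise ValueError(f"Key {name} is an invalid object name.")
        if exact_match and set(state) != set(targets[name]):
            # Loose analogue of ``load_state_dict(..., strict=True)``
            raise ValueError(f"State keys for {name} do not match")
        # Non-strict loading: keys unknown to the target are ignored
        targets[name].update({key: value for key, value in state.items() if key in targets[name]})
        updated.add(name)
    # With exact_match, every known object must have been covered
    if exact_match and updated != set(targets):
        raise ValueError(f"Names of parameters do not match: expected {set(targets)}, got {updated}")
    return updated


model_state = {"policy": {"w": 0.0, "b": 0.0}, "policy.optimizer": {"lr": 3e-4}}
# Partial update of the policy weights only, so exact_match must be False
print(load_named_states(model_state, {"policy": {"w": 1.0}}, exact_match=False))  # {'policy'}

rejected = False
try:
    # Keys match for "policy", but "policy.optimizer" is missing from the update
    load_named_states(model_state, {"policy": {"w": 2.0, "b": 2.0}}, exact_match=True)
except ValueError as err:
    rejected = True
    print("rejected:", err)
```

As in `set_parameters`, the coverage check runs after loading, so a rejected call may already have modified some objects (here `model_state["policy"]` now holds the new values).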

    @classmethod
    def load(
        cls: Type[SelfBaseAlgorithm],
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        env: Optional[GymEnv] = None,
        device: Union[th.device, str] = "auto",
        custom_objects: Optional[Dict[str, Any]] = None,
        print_system_info: bool = False,
        force_reset: bool = True,
        **kwargs,
    ) -> SelfBaseAlgorithm:
        """
        Load the model from a zip-file.
        Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
        For an in-place load use ``set_parameters`` instead.

        :param path: path to the file (or a file-like) where to
            load the agent from
        :param env: the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model); it has priority over any saved environment
        :param device: Device on which the code should run.
        :param custom_objects: Dictionary of objects to replace
            upon loading. If a variable is present in this dictionary as a
            key, it will not be deserialized and the corresponding item
            will be used instead. Similar to custom_objects in
            ``keras.models.load_model``. Useful when you have an object in
            the file that cannot be deserialized.
        :param print_system_info: Whether to print system info from the saved model
            and the current system info (useful to debug loading issues)
        :param force_reset: Force call to ``reset()`` before training
            to avoid unexpected behavior.
            See https://github.com/DLR-RM/stable-baselines3/issues/597
        :param kwargs: extra arguments to change the model when loading
        :return: new model instance with loaded parameters
        """
        if print_system_info:
            print("== CURRENT SYSTEM INFO ==")
            get_system_info()

        data, params, pytorch_variables = load_from_zip_file(
            path,
            device=device,
            custom_objects=custom_objects,
            print_system_info=print_system_info,
        )

        # Remove stored device information and replace with ours
        if "policy_kwargs" in data:
            if "device" in data["policy_kwargs"]:
                del data["policy_kwargs"]["device"]

        if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
            raise ValueError(
                "The specified policy kwargs do not equal the stored policy kwargs. "
                f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
            )

        if "observation_space" not in data or "action_space" not in data:
            raise KeyError("The observation_space and action_space were not given, can't verify new environments")

        if env is not None:
            # Wrap first if needed
            env = cls._wrap_env(env, data["verbose"])
            # Check if given env is valid
            check_for_correct_spaces(env, data["observation_space"], data["action_space"])
            # Discard `_last_obs`; this forces the env to reset before training
            # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
            if force_reset and data is not None:
                data["_last_obs"] = None
            # `n_envs` must be updated. See issue https://github.com/DLR-RM/stable-baselines3/issues/1018
            if data is not None:
                data["n_envs"] = env.num_envs
        else:
            # Use stored env, if one exists. If not, continue as is (can be used for predict)
            if "env" in data:
                env = data["env"]

        # noinspection PyArgumentList
        model = cls(  # pytype: disable=not-instantiable,wrong-keyword-args
            policy=data["policy_class"],
            env=env,
            device=device,
            _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
        )

        # load parameters
        model.__dict__.update(data)
        model.__dict__.update(kwargs)
        model._setup_model()

        # put state_dicts back in place
        model.set_parameters(params, exact_match=True, device=device)

        # put other pytorch variables back in place
        if pytorch_variables is not None:
            for name in pytorch_variables:
                # Skip if PyTorch variable was not defined (to ensure backward compatibility).
                # This happens when using SAC/TQC.
                # SAC has an entropy coefficient which can be fixed or optimized.
                # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
                # otherwise it is initialized to `None`.
                if pytorch_variables[name] is None:
                    continue
                # Set the data attribute directly to avoid issue when using optimizers
                # See https://github.com/DLR-RM/stable-baselines3/issues/391
                recursive_setattr(model, name + ".data", pytorch_variables[name].data)

        # Sample gSDE exploration matrix, so it uses the right device
        # see issue #44
        if model.use_sde:
            model.policy.reset_noise()  # pytype: disable=attribute-error
        return model
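Two steps in `load` are easy to misread: the `policy_kwargs` check, which ignores the stored `device`, and the order of the `__dict__` updates, which lets `kwargs` override stored values. The stdlib-only sketch below mirrors both; `check_load_kwargs` is hypothetical, and the dicts stand in for the deserialized zip data.

```python
from types import SimpleNamespace
from typing import Any, Dict


def check_load_kwargs(data: Dict[str, Any], kwargs: Dict[str, Any]) -> None:
    """Hypothetical mirror of the ``policy_kwargs`` handling in ``load``."""
    # The saved device is dropped so that the caller's ``device`` argument wins
    if "policy_kwargs" in data:
        data["policy_kwargs"].pop("device", None)
    # ``.get`` avoids a KeyError when nothing was stored (a small deviation from the original)
    if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data.get("policy_kwargs"):
        raise ValueError(
            "The specified policy kwargs do not equal the stored policy kwargs. "
            f"Stored kwargs: {data.get('policy_kwargs')}, specified kwargs: {kwargs['policy_kwargs']}"
        )


data = {"learning_rate": 3e-4, "policy_kwargs": {"net_arch": [64, 64], "device": "cuda"}}
kwargs = {"learning_rate": 1e-4}
check_load_kwargs(data, kwargs)

model = SimpleNamespace()
model.__dict__.update(data)
# Applied last, so values passed to ``load`` override the stored ones
model.__dict__.update(kwargs)
print(model.learning_rate, model.policy_kwargs)  # 0.0001 {'net_arch': [64, 64]}
```

This is why architecture-defining `policy_kwargs` must match the saved ones, while scalar settings such as a learning rate can be changed freely at load time.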

    def get_parameters(self) -> Dict[str, Dict]:
        """
        Return the parameters of the agent. This includes parameters from different networks, e.g.
        critics (value functions) and policies (pi functions).

        :return: Mapping from names of the objects to PyTorch state-dicts.
        """
        state_dicts_names, _ = self._get_torch_save_params()
        params = {}
        for name in state_dicts_names:
            attr = recursive_getattr(self, name)
            # Retrieve state dict
            params[name] = attr.state_dict()
        return params
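Object names such as `"policy.optimizer"` are dotted paths, resolved by the `recursive_getattr` / `recursive_setattr` helpers, whose definitions are not part of this excerpt. The following is a hypothetical reimplementation built on `functools.reduce`, with a `FakeModule` standing in for `torch.nn.Module`, used in the same loop shape as `get_parameters`.

```python
import functools
from types import SimpleNamespace
from typing import Any, Dict


def recursive_getattr_sketch(obj: Any, attr: str) -> Any:
    """Hypothetical: resolve a dotted name such as "critic.target" one attribute at a time."""
    return functools.reduce(getattr, attr.split("."), obj)


def recursive_setattr_sketch(obj: Any, attr: str, value: Any) -> None:
    """Hypothetical: set the last component of a dotted name on its resolved parent."""
    parent_path, _, last = attr.rpartition(".")
    parent = recursive_getattr_sketch(obj, parent_path) if parent_path else obj
    setattr(parent, last, value)


class FakeModule:
    """Stand-in for ``torch.nn.Module`` exposing only ``state_dict()``."""

    def __init__(self, **weights: float) -> None:
        self._weights = weights

    def state_dict(self) -> Dict[str, float]:
        # Return a copy so callers cannot mutate the stored weights
        return dict(self._weights)


agent = SimpleNamespace(policy=FakeModule(w=1.0), critic=SimpleNamespace(target=FakeModule(b=0.5)))
# Same loop shape as ``get_parameters``
params = {name: recursive_getattr_sketch(agent, name).state_dict() for name in ["policy", "critic.target"]}
print(params)  # {'policy': {'w': 1.0}, 'critic.target': {'b': 0.5}}

recursive_setattr_sketch(agent, "critic.tau", 0.005)
print(agent.critic.tau)  # 0.005
```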

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.

        :param path: path to the file where the RL agent should be saved
        :param exclude: names of parameters that should be excluded in addition to the default ones
        :param include: names of parameters that might be excluded but should be included anyway
        """
        # Copy parameter list so we don't mutate the original dict
        data = self.__dict__.copy()

        # Exclude is union of specified parameters (if any) and standard exclusions
        if exclude is None:
            exclude = []
        exclude = set(exclude).union(self._excluded_save_params())

        # Do not exclude params if they are specifically included
        if include is not None:
            exclude = exclude.difference(include)

        state_dicts_names, torch_variable_names = self._get_torch_save_params()
        all_pytorch_variables = state_dicts_names + torch_variable_names
        for torch_var in all_pytorch_variables:
            # We need only the name of the top-level attribute, as that is what we remove
            var_name = torch_var.split(".")[0]
            # Any params that are in the save vars must not be saved by data
            exclude.add(var_name)

        # Remove parameter entries of parameters which are to be excluded
        for param_name in exclude:
            data.pop(param_name, None)

        # Build dict of torch variables
        pytorch_variables = None
        if torch_variable_names is not None:
            pytorch_variables = {}
            for name in torch_variable_names:
                attr = recursive_getattr(self, name)
                pytorch_variables[name] = attr

        # Build dict of state_dicts
        params_to_save = self.get_parameters()

        save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
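The order of the set operations in `save` matters: `include` can cancel user or default exclusions, but PyTorch objects are removed afterwards regardless, because they are stored separately as state dicts. A hypothetical helper, `compute_excluded`, reproduces that order; the default exclusions and attribute values below are illustrative, not SB3's actual lists.

```python
from typing import Iterable, List, Optional, Set


def compute_excluded(
    default_excluded: Iterable[str],
    torch_names: List[str],
    exclude: Optional[Iterable[str]] = None,
    include: Optional[Iterable[str]] = None,
) -> Set[str]:
    """Hypothetical mirror of how ``save`` decides which attributes stay out of ``data``."""
    excluded = set(exclude or []).union(default_excluded)
    # ``include`` only cancels user/default exclusions...
    if include is not None:
        excluded = excluded.difference(include)
    # ...while PyTorch objects are always dropped (by top-level attribute name)
    for name in torch_names:
        excluded.add(name.split(".")[0])
    return excluded


attributes = {"learning_rate": 3e-4, "env": "env", "replay_buffer": [], "policy": "nn", "gamma": 0.99}
excluded = compute_excluded(
    default_excluded=["env", "replay_buffer"],
    torch_names=["policy", "policy.optimizer"],
    exclude=["gamma"],
    include=["env", "policy"],
)
data = {key: value for key, value in attributes.items() if key not in excluded}
print(sorted(data))  # ['env', 'learning_rate']
```

Even though `"policy"` is passed in `include`, it stays out of `data`, since its weights are saved through `get_parameters` instead.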