From 254bb10c42e8f892e43af9da25aefc7c604c317c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20Passault?= Date: Fri, 8 Apr 2022 15:21:53 -0400 Subject: [PATCH 01/33] Replacing the policy registry with policy "aliases" (#842) * Replacing the policy registry with policy "aliases" * Fixing import order and SAC * Changing arg. order to be sure policy_aliases is a kwarg * Import orders * Removing pytype error check * Reformat * Fix alias import * Not using mutable {} as default for policy_aliases * Empty aliases initialization * Using static attributes for policy_aliases * Fixing isort * Fixing back bad merge * Running isort * Fixing aliases for A2C and PPO * Using f-string * Moving policy_aliases definition position * Adding change in the changelog * Update version Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 4 +- stable_baselines3/a2c/a2c.py | 8 ++- stable_baselines3/a2c/policies.py | 11 +--- stable_baselines3/common/base_class.py | 29 +++++++-- .../common/off_policy_algorithm.py | 3 - .../common/on_policy_algorithm.py | 5 +- stable_baselines3/common/policies.py | 65 ------------------- stable_baselines3/dqn/dqn.py | 10 ++- stable_baselines3/dqn/policies.py | 7 +- stable_baselines3/ppo/policies.py | 11 +--- stable_baselines3/ppo/ppo.py | 8 ++- stable_baselines3/sac/policies.py | 7 +- stable_baselines3/sac/sac.py | 10 ++- stable_baselines3/td3/policies.py | 7 +- stable_baselines3/td3/td3.py | 10 ++- stable_baselines3/version.txt | 2 +- 16 files changed, 71 insertions(+), 126 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e17c3df..b209f16 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,11 +4,13 @@ Changelog ========== -Release 1.5.1a0 (WIP) +Release 1.5.1a1 (WIP) --------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former + ``register_policy`` helper, ``policy_base`` parameter and 
using ``policy_aliases`` static attributes instead (@Gregwar) New Features: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 837ec42..eeeb670 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -5,7 +5,7 @@ from gym import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance @@ -51,6 +51,12 @@ class A2C(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], diff --git a/stable_baselines3/a2c/policies.py b/stable_baselines3/a2c/policies.py index 79c85f8..7299b34 100644 --- a/stable_baselines3/a2c/policies.py +++ b/stable_baselines3/a2c/policies.py @@ -1,16 +1,7 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for A2C -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy - -register_policy("MlpPolicy", ActorCriticPolicy) -register_policy("CnnPolicy", ActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git 
a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 25c2638..14570be 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -17,7 +17,7 @@ from stable_baselines3.common.env_util import is_wrapped from stable_baselines3.common.logger import Logger from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.policies import BasePolicy, get_policy_from_name +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -60,7 +60,6 @@ class BaseAlgorithm(ABC): :param policy: Policy object :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) - :param policy_base: The base policy used by this method :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param policy_kwargs: Additional arguments to be passed to the policy on creation @@ -83,11 +82,13 @@ class BaseAlgorithm(ABC): :param supported_action_spaces: The action spaces supported by the algorithm. 
""" + # Policy aliases (see _get_policy_from_name()) + policy_aliases: Dict[str, Type[BasePolicy]] = {} + def __init__( self, policy: Type[BasePolicy], env: Union[GymEnv, str, None], - policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], policy_kwargs: Optional[Dict[str, Any]] = None, tensorboard_log: Optional[str] = None, @@ -101,9 +102,8 @@ class BaseAlgorithm(ABC): sde_sample_freq: int = -1, supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - - if isinstance(policy, str) and policy_base is not None: - self.policy_class = get_policy_from_name(policy_base, policy) + if isinstance(policy, str): + self.policy_class = self._get_policy_from_name(policy) else: self.policy_class = policy @@ -325,6 +325,23 @@ class BaseAlgorithm(ABC): "_custom_logger", ] + def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]: + """ + Get a policy class from its name representation. + + The goal here is to standardize policy naming, e.g. + all algorithms can call upon "MlpPolicy" or "CnnPolicy", + and they receive respective policies that work for them. + + :param policy_name: Alias of the policy + :return: A policy class (type) + """ + + if policy_name in self.policy_aliases: + return self.policy_aliases[policy_name] + else: + raise ValueError(f"Policy {policy_name} unknown") + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: """ Get the name of the torch variables that will be saved with diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 27e8bdd..5905dee 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -28,7 +28,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): :param policy: Policy object :param env: The environment to learn from (if registered in Gym, can be str. 
Can be None for loading trained models) - :param policy_base: The base policy used by this method :param learning_rate: learning rate for the optimizer, it can be a function of the current progress remaining (from 1 to 0) :param buffer_size: size of the replay buffer @@ -76,7 +75,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): self, policy: Type[BasePolicy], env: Union[GymEnv, str], - policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, @@ -107,7 +105,6 @@ class OffPolicyAlgorithm(BaseAlgorithm): super(OffPolicyAlgorithm, self).__init__( policy=policy, env=env, - policy_base=policy_base, learning_rate=learning_rate, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 48cb365..281758c 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -8,7 +8,7 @@ import torch as th from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy +from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -34,7 +34,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) - :param policy_base: The base policy used by this method :param tensorboard_log: the log location for tensorboard (if None, no logging) :param 
create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) @@ -62,7 +61,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): max_grad_norm: float, use_sde: bool, sde_sample_freq: int, - policy_base: Type[BasePolicy] = ActorCriticPolicy, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, monitor_wrapper: bool = True, @@ -77,7 +75,6 @@ class OnPolicyAlgorithm(BaseAlgorithm): super(OnPolicyAlgorithm, self).__init__( policy=policy, env=env, - policy_base=policy_base, learning_rate=learning_rate, policy_kwargs=policy_kwargs, verbose=verbose, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 33918b7..c322dc6 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -894,68 +894,3 @@ class ContinuousCritic(BaseModel): with th.no_grad(): features = self.extract_features(obs) return self.q_networks[0](th.cat([features, actions], dim=1)) - - -_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] - - -def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: - """ - Returns the registered policy from the base type and name. - See `register_policy` for registering policies and explanation. - - :param base_policy_type: the base policy class - :param name: the policy name - :return: the policy - """ - if base_policy_type not in _policy_registry: - raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") - if name not in _policy_registry[base_policy_type]: - raise KeyError( - f"Error: unknown policy type {name}," - f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!" - ) - return _policy_registry[base_policy_type][name] - - -def register_policy(name: str, policy: Type[BasePolicy]) -> None: - """ - Register a policy, so it can be called using its name. 
- e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...). - - The goal here is to standardize policy naming, e.g. - all algorithms can call upon "MlpPolicy" or "CnnPolicy", - and they receive respective policies that work for them. - Consider following: - - OnlinePolicy - -- OnlineMlpPolicy ("MlpPolicy") - -- OnlineCnnPolicy ("CnnPolicy") - OfflinePolicy - -- OfflineMlpPolicy ("MlpPolicy") - -- OfflineCnnPolicy ("CnnPolicy") - - Two policies have name "MlpPolicy" and two have "CnnPolicy". - In `get_policy_from_name`, the parent class (e.g. OnlinePolicy) - is given and used to select and return the correct policy. - - :param name: the policy name - :param policy: the policy class - """ - sub_class = None - for cls in BasePolicy.__subclasses__(): - if issubclass(policy, cls): - sub_class = cls - break - if sub_class is None: - raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") - - if sub_class not in _policy_registry: - _policy_registry[sub_class] = {} - if name in _policy_registry[sub_class]: - # Check if the registered policy is same - # we try to register. If not so, - # do not override and complain. 
- if _policy_registry[sub_class][name] != policy: - raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.") - _policy_registry[sub_class][name] = policy diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index a7aec6b..ed6073b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -8,10 +8,11 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update -from stable_baselines3.dqn.policies import DQNPolicy +from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy class DQN(OffPolicyAlgorithm): @@ -59,6 +60,12 @@ class DQN(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[DQNPolicy]], @@ -91,7 +98,6 @@ class DQN(OffPolicyAlgorithm): super(DQN, self).__init__( policy, env, - DQNPolicy, learning_rate, buffer_size, learning_starts, diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 099a4e3..ea00b5c 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from torch import nn -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.torch_layers import ( 
BaseFeaturesExtractor, CombinedExtractor, @@ -296,8 +296,3 @@ class MultiInputPolicy(DQNPolicy): optimizer_class, optimizer_kwargs, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 7427cfc..fb7afae 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -1,16 +1,7 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for PPO -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy MultiInputPolicy = MultiInputActorCriticPolicy - -register_policy("MlpPolicy", ActorCriticPolicy) -register_policy("CnnPolicy", ActorCriticCnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 088bab3..0d05b4c 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -7,7 +7,7 @@ from gym import spaces from torch.nn import functional as F from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn @@ -65,6 +65,12 @@ class PPO(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, 
+ "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 0bd1382..cb6a61c 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -6,7 +6,7 @@ import torch as th from torch import nn from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution -from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -514,8 +514,3 @@ class MultiInputPolicy(SACPolicy): n_critics, share_features_extractor, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 5f3a833..3703b73 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -8,9 +8,10 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update -from stable_baselines3.sac.policies import SACPolicy +from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy class SAC(OffPolicyAlgorithm): @@ -72,6 +73,12 @@ class SAC(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the 
instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[SACPolicy]], @@ -106,7 +113,6 @@ class SAC(OffPolicyAlgorithm): super(SAC, self).__init__( policy, env, - SACPolicy, learning_rate, buffer_size, learning_starts, diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 264c760..ce91a0f 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -4,7 +4,7 @@ import gym import torch as th from torch import nn -from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, @@ -351,8 +351,3 @@ class MultiInputPolicy(TD3Policy): n_critics, share_features_extractor, ) - - -register_policy("MlpPolicy", MlpPolicy) -register_policy("CnnPolicy", CnnPolicy) -register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index eb257a6..d31720b 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -8,9 +8,10 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import polyak_update -from stable_baselines3.td3.policies import TD3Policy +from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy class TD3(OffPolicyAlgorithm): @@ -60,6 +61,12 @@ class 
TD3(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": MlpPolicy, + "CnnPolicy": CnnPolicy, + "MultiInputPolicy": MultiInputPolicy, + } + def __init__( self, policy: Union[str, Type[TD3Policy]], @@ -91,7 +98,6 @@ class TD3(OffPolicyAlgorithm): super(TD3, self).__init__( policy, env, - TD3Policy, learning_rate, buffer_size, learning_starts, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 33271c4..1110517 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a0 +1.5.1a1 From 16703b13143eb2b55e216ac831758d523165ff1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 11 Apr 2022 17:50:02 +0200 Subject: [PATCH 02/33] Fix HER goal selection (#848) * Goal sampled from next_achieved_goal instead of achived_goal * No need to have special case for future anymore * Update changelog Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 3 ++- stable_baselines3/her/her_replay_buffer.py | 12 ++---------- stable_baselines3/version.txt | 2 +- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b209f16..652b5a6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a1 (WIP) +Release 1.5.1a2(WIP) --------------------------- Breaking Changes: @@ -21,6 +21,7 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ - Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) +- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 9a41477..f61a786 100644 --- 
a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -252,7 +252,7 @@ class HerReplayBuffer(DictReplayBuffer): elif self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: # replay with random state which comes from the same episode and was observed after current transition transitions_indices = np.random.randint( - transitions_indices[her_indices] + 1, self.episode_lengths[her_episode_indices] + transitions_indices[her_indices], self.episode_lengths[her_episode_indices] ) elif self.goal_selection_strategy == GoalSelectionStrategy.EPISODE: @@ -262,7 +262,7 @@ class HerReplayBuffer(DictReplayBuffer): else: raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!") - return self._buffer["achieved_goal"][her_episode_indices, transitions_indices] + return self._buffer["next_achieved_goal"][her_episode_indices, transitions_indices] def _sample_transitions( self, @@ -304,14 +304,6 @@ class HerReplayBuffer(DictReplayBuffer): ep_lengths = self.episode_lengths[episode_indices] - # Special case when using the "future" goal sampling strategy - # we cannot sample all transitions, we have to remove the last timestep - if self.goal_selection_strategy == GoalSelectionStrategy.FUTURE: - # restrict the sampling domain when ep_lengths > 1 - # otherwise filter out the indices - her_indices = her_indices[ep_lengths[her_indices] > 1] - ep_lengths[her_indices] -= 1 - if online_sampling: # Select which transitions to use transitions_indices = np.random.randint(ep_lengths) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 1110517..1a2eef7 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a1 +1.5.1a2 From 248f082cdc4c64974a602f5359799e16ca17f2cf Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 11 Apr 2022 18:34:15 +0200 Subject: [PATCH 03/33] Bump min PyTorch version (#855) --- .github/workflows/ci.yml | 2 +- README.md | 
4 ++-- docs/guide/custom_env.rst | 6 +++--- docs/guide/install.rst | 2 +- docs/misc/changelog.rst | 4 +++- setup.py | 2 +- stable_baselines3/version.txt | 2 +- 7 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b45ae31..4bc23e2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install .[extra,tests,docs] # Use headless version pip install opencv-python-headless diff --git a/README.md b/README.md index 2a0701c..f727547 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Documentation is available online: [https://stable-baselines3.readthedocs.io/](h ## Integrations -Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation. +Stable-Baselines3 has some integration with other libraries/services like Weights & Biases for experiment tracking or Hugging Face for storing/sharing trained models. You can find out more in the [dedicated section](https://stable-baselines3.readthedocs.io/en/master/guide/integrations.html) of the documentation. ## RL Baselines3 Zoo: A Training Framework for Stable Baselines3 Reinforcement Learning Agents @@ -84,7 +84,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/ ## Installation -**Note:** Stable-Baselines3 supports PyTorch >= 1.8.1. +**Note:** Stable-Baselines3 supports PyTorch >= 1.11 ### Prerequisites Stable Baselines3 requires Python 3.7+. 
diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 355ecb2..2e2d1f7 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -61,7 +61,7 @@ Then you can define and train a RL agent with: model = A2C('CnnPolicy', env).learn(total_timesteps=1000) -To check that your environment follows the gym interface, please use: +To check that your environment follows the Gym interface that SB3 supports, please use: .. code-block:: python @@ -71,11 +71,11 @@ To check that your environment follows the gym interface, please use: # It will check your custom environment and output additional warnings if needed check_env(env) - +Gym also have its own `env checker `_ but it checks a superset of what SB3 supports (SB3 does not support all Gym features). We have created a `colab notebook `_ for a concrete example on creating a custom environment along with an example of using it with Stable-Baselines3 interface. -Alternatively, you may look at OpenAI Gym `built-in environments `_. However, the readers are cautioned as per OpenAI Gym `official wiki `_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them. +Alternatively, you may look at OpenAI Gym `built-in environments `_. However, the readers are cautioned as per OpenAI Gym `official wiki `_, its advised not to customize their built-in environments. It is better to copy and create new ones if you need to modify them. Optionally, you can also register the environment with gym, that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env): diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 7beabb7..3b26927 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -6,7 +6,7 @@ Installation Prerequisites ------------- -Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.8.1. 
+Stable-Baselines3 requires python 3.7+ and PyTorch >= 1.11 Windows 10 ~~~~~~~~~~ diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 652b5a6..3179605 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,13 +4,14 @@ Changelog ========== -Release 1.5.1a2(WIP) +Release 1.5.1a3 (WIP) --------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) +- SB3 now requires PyTorch >= 1.11 New Features: ^^^^^^^^^^^^^ @@ -31,6 +32,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ +- Added link to gym doc and gym env checker Release 1.5.0 (2022-03-25) diff --git a/setup.py b/setup.py index de615a7..3664bbc 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ setup( install_requires=[ "gym==0.21", # Fixed version due to breaking changes in 0.22 "numpy", - "torch>=1.8.1", + "torch>=1.11", # For saving models "cloudpickle", # For reading logs diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 1a2eef7..8d61b2f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a2 +1.5.1a3 From 39a4f9379a8068110c895c4bb18cb0e4e20cd69c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 11 Apr 2022 21:49:18 +0200 Subject: [PATCH 04/33] Escape tensorboard log name (#857) * escape tensorboard log name Otherwise utils does not recognize the log. * Added fix to changelog * Modifications made by: make commit-checks . * Revert "Modifications made by: make commit-checks ." This reverts commit 529a275d9475f85ef031038a8f3565f7301e5371. 
* Update changelog and add test Co-authored-by: James Hirschorn --- docs/misc/changelog.rst | 5 +++-- stable_baselines3/common/utils.py | 7 +++++-- stable_baselines3/version.txt | 2 +- tests/test_tensorboard.py | 11 +++++++++++ 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3179605..3895380 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a3 (WIP) +Release 1.5.1a4 (WIP) --------------------------- Breaking Changes: @@ -23,6 +23,7 @@ Bug Fixes: ^^^^^^^^^^ - Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) - Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) +- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) Deprecations: ^^^^^^^^^^^^^ @@ -962,4 +963,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 +@Gregwar @ycheng517 @quantitative-technologies diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 8504c8d..94cd658 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -154,15 +154,18 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: return device -def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int: +def get_latest_run_id(log_path: str = "", log_name: str = "") -> int: """ Returns the latest run number for the given log name and log path, by finding the greatest number in the directories. + :param log_path: Path to the log folder containing several runs. + :param log_name: Name of the experiment. 
Each run is stored + in a folder named ``log_name_1``, ``log_name_2``, ... :return: latest run number """ max_run_id = 0 - for path in glob.glob(f"{log_path}/{log_name}_[0-9]*"): + for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")): file_name = path.split(os.sep)[-1] ext = file_name.split("_")[-1] if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 8d61b2f..d6a9f8c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a3 +1.5.1a4 diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py index 20f58b9..6dccf41 100644 --- a/tests/test_tensorboard.py +++ b/tests/test_tensorboard.py @@ -3,6 +3,7 @@ import os import pytest from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.utils import get_latest_run_id MODEL_DICT = { "a2c": (A2C, "CartPole-v1"), @@ -35,3 +36,13 @@ def test_tensorboard(tmp_path, model_name): assert os.path.isdir(tmp_path / str(logname + "_1")) # Check that the log dir name increments correctly assert os.path.isdir(tmp_path / str(logname + "_2")) + + +def test_escape_log_name(tmp_path): + # Log name that must be escaped + log_name = "filename[16, 16]" + # Create folder + os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True) + os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True) + last_run_id = get_latest_run_id(tmp_path, log_name) + assert last_run_id == 2 From ed308a71be24036744b5ad4af61b083e4fbdf83c Mon Sep 17 00:00:00 2001 From: Paul Scheikl Date: Tue, 12 Apr 2022 16:05:40 +0200 Subject: [PATCH 05/33] Fixed unchecked None value in SubprocVecEnv (#808) * Fixed unchecked None value in SubprocVecEnv * Fixed unchecked None value in DummyVecEnv * Fix formatting * Update test and changelog * Improve test Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 1 + .../common/vec_env/dummy_vec_env.py | 4 +- 
.../common/vec_env/subproc_vec_env.py | 2 + tests/test_vec_envs.py | 37 ++++++++++++++++++- 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3895380..8f37265 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -24,6 +24,7 @@ Bug Fixes: - Fixed saving and loading large policies greater than 2GB (@jkterry1, @ycheng517) - Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) - Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) +- Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 5eb87cd..c0efc8c 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -51,7 +51,9 @@ class DummyVecEnv(VecEnv): return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: - seeds = list() + if seed is None: + seed = np.random.randint(0, 2**32 - 1) + seeds = [] for idx, env in enumerate(self.envs): seeds.append(env.seed(seed + idx)) return seeds diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 1050f3e..04f5d0c 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -123,6 +123,8 @@ class SubprocVecEnv(VecEnv): return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + if seed is None: + seed = np.random.randint(0, 2**32 - 1) for idx, remote in enumerate(self.remotes): remote.send(("seed", seed + 
idx)) return [remote.recv() for remote in self.remotes] diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 9a4c118..93ea348 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -31,7 +31,7 @@ class CustomGymEnv(gym.Env): return self.state def step(self, action): - reward = 1 + reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 done = self.current_step >= self.ep_length @@ -45,7 +45,9 @@ class CustomGymEnv(gym.Env): return np.zeros((4, 4, 3)) def seed(self, seed=None): - pass + if seed is not None: + np.random.seed(seed) + self.observation_space.seed(seed) @staticmethod def custom_method(dim_0=1, dim_1=1): @@ -440,3 +442,34 @@ def test_vec_env_is_wrapped(): vec_env = VecFrameStack(vec_env, n_stack=2) assert vec_env.env_is_wrapped(Monitor) == [False, True] + + +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_vec_seeding(vec_env_class): + def make_env(): + return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + + # For SubprocVecEnv check for all starting methods + start_methods = [None] + if vec_env_class != DummyVecEnv: + all_methods = {"forkserver", "spawn", "fork"} + available_methods = multiprocessing.get_all_start_methods() + start_methods = list(all_methods.intersection(available_methods)) + + for start_method in start_methods: + if start_method is not None: + vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method) + + n_envs = 3 + vec_env = vec_env_class([make_env] * n_envs) + # Seed with no argument + vec_env.seed() + obs = vec_env.reset() + _, rewards, _, _ = vec_env.step(np.array([vec_env.action_space.sample() for _ in range(n_envs)])) + # Seed should be different per process + assert not np.allclose(obs[0], obs[1]) + assert not np.allclose(rewards[0], rewards[1]) + assert not np.allclose(obs[1], obs[2]) + assert not np.allclose(rewards[1], rewards[2]) + + vec_env.close() From 3c468ff5582f1c05374f674331fe8ff8e7cfc70d Mon Sep 17 00:00:00 
2001 From: Bryan Collazo Date: Tue, 19 Apr 2022 08:15:51 -0400 Subject: [PATCH 06/33] Update ppo documentation (remove redundant and) (#874) * Update ppo documentation (remove redundant and) PTAL, thanks! * Update changelog * Pin ale-py version Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 3 ++- setup.py | 2 +- stable_baselines3/ppo/ppo.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8f37265..da33766 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -35,6 +35,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ - Added link to gym doc and gym env checker +- Fix typo in PPO doc (@bcollazo) Release 1.5.0 (2022-03-25) @@ -964,4 +965,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies +@Gregwar @ycheng517 @quantitative-technologies @bcollazo diff --git a/setup.py b/setup.py index 3664bbc..bb53f06 100644 --- a/setup.py +++ b/setup.py @@ -116,7 +116,7 @@ setup( # For render "opencv-python", # For atari games, - "ale-py~=0.7.4", + "ale-py==0.7.4", "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 0d05b4c..346cc02 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -19,7 +19,7 @@ class PPO(OnPolicyAlgorithm): Paper: https://arxiv.org/abs/1707.06347 Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and - and Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines) + Stable Baselines (PPO2 from 
https://github.com/hill-a/stable-baselines) Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html From 061841a314c361d342442d693af85f6abd838003 Mon Sep 17 00:00:00 2001 From: code-review-doctor <72647856+code-review-doctor@users.noreply.github.com> Date: Sun, 24 Apr 2022 08:19:06 +0100 Subject: [PATCH 07/33] Missing f prefix on f-strings (#882) --- stable_baselines3/common/vec_env/base_vec_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index d3e624a..9870605 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -305,7 +305,7 @@ class VecEnvWrapper(VecEnv): own_class = f"{type(self).__module__}.{type(self).__name__}" error_str = ( f"Error: Recursive attribute lookup for {name} from {own_class} is " - "ambiguous and hides attribute from {blocked_class}" + f"ambiguous and hides attribute from {blocked_class}" ) raise AttributeError(error_str) From a6f5049a99a4c21a6f0bcce458ca3306cef310e0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 25 Apr 2022 12:01:38 +0200 Subject: [PATCH 08/33] Upgrade code to Python 3.7+ syntax using `pyupgrade` (#887) * Upgrade code to Python 3.7+ syntax * Update changelog --- docs/conf.py | 3 +-- docs/misc/changelog.rst | 3 ++- setup.py | 2 +- stable_baselines3/__init__.py | 2 +- stable_baselines3/a2c/a2c.py | 4 ++-- stable_baselines3/common/atari_wrappers.py | 2 +- stable_baselines3/common/buffers.py | 8 +++---- stable_baselines3/common/callbacks.py | 22 +++++++++---------- stable_baselines3/common/distributions.py | 22 +++++++++---------- .../common/envs/bit_flipping_env.py | 6 ++--- .../common/envs/multi_input_envs.py | 2 +- stable_baselines3/common/logger.py | 16 +++++++------- stable_baselines3/common/monitor.py | 6 ++--- stable_baselines3/common/noise.py | 6 ++--- .../common/off_policy_algorithm.py | 2 +- 
.../common/on_policy_algorithm.py | 2 +- stable_baselines3/common/policies.py | 10 ++++----- stable_baselines3/common/running_mean_std.py | 2 +- .../common/sb2_compat/rmsprop_tf_like.py | 14 ++++++------ stable_baselines3/common/torch_layers.py | 10 ++++----- .../common/vec_env/stacked_observations.py | 2 +- .../common/vec_env/subproc_vec_env.py | 2 +- stable_baselines3/common/vec_env/util.py | 2 +- .../common/vec_env/vec_transpose.py | 4 ++-- stable_baselines3/ddpg/ddpg.py | 4 ++-- stable_baselines3/dqn/dqn.py | 8 +++---- stable_baselines3/dqn/policies.py | 8 +++---- stable_baselines3/her/her_replay_buffer.py | 2 +- stable_baselines3/ppo/ppo.py | 6 ++--- stable_baselines3/sac/policies.py | 8 +++---- stable_baselines3/sac/sac.py | 10 ++++----- stable_baselines3/td3/policies.py | 8 +++---- stable_baselines3/td3/td3.py | 10 ++++----- stable_baselines3/version.txt | 2 +- tests/test_gae.py | 6 ++--- tests/test_her.py | 2 +- tests/test_monitor.py | 8 +++---- tests/test_save_load.py | 6 ++--- tests/test_spaces.py | 4 ++-- tests/test_train_eval_mode.py | 2 +- tests/test_utils.py | 2 +- tests/test_vec_check_nan.py | 2 +- tests/test_vec_monitor.py | 4 ++-- tests/test_vec_normalize.py | 2 +- 44 files changed, 129 insertions(+), 129 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 088f8a0..712908e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. 
# @@ -46,7 +45,7 @@ sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # Read version from file version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() # -- Project information ----------------------------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index da33766..409b672 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a4 (WIP) +Release 1.5.1a5 (WIP) --------------------------- Breaking Changes: @@ -31,6 +31,7 @@ Deprecations: Others: ^^^^^^^ +- Upgraded to Python 3.7+ syntax using ``pyupgrade`` Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index bb53f06..cd8f209 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import os from setuptools import find_packages, setup -with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler: +with open(os.path.join("stable_baselines3", "version.txt")) as file_handler: __version__ = file_handler.read().strip() diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 4e31c5b..d73f5f0 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -11,7 +11,7 @@ from stable_baselines3.td3 import TD3 # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index eeeb670..13adf68 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -82,7 +82,7 @@ class A2C(OnPolicyAlgorithm): _init_setup_model: bool = True, ): - super(A2C, self).__init__( + super().__init__( policy, env, 
learning_rate=learning_rate, @@ -194,7 +194,7 @@ class A2C(OnPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> "A2C": - return super(A2C, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 832ad9f..a9b2eca 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -245,4 +245,4 @@ class AtariWrapper(gym.Wrapper): if clip_reward: env = ClipRewardEnv(env) - super(AtariWrapper, self).__init__(env) + super().__init__(env) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index bba2272..d7728cb 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -42,7 +42,7 @@ class BaseBuffer(ABC): device: Union[th.device, str] = "cpu", n_envs: int = 1, ): - super(BaseBuffer, self).__init__() + super().__init__() self.buffer_size = buffer_size self.observation_space = observation_space self.action_space = action_space @@ -179,7 +179,7 @@ class ReplayBuffer(BaseBuffer): optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, ): - super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) # Adjust buffer size self.buffer_size = max(buffer_size // n_envs, 1) @@ -339,7 +339,7 @@ class RolloutBuffer(BaseBuffer): n_envs: int = 1, ): - super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) self.gae_lambda = gae_lambda self.gamma = gamma self.observations, self.actions, self.rewards, self.advantages = None, None, None, None @@ -358,7 +358,7 @@ class RolloutBuffer(BaseBuffer): self.log_probs = 
np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.generator_ready = False - super(RolloutBuffer, self).reset() + super().reset() def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None: """ diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 27ce5e6..c5f297c 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -19,7 +19,7 @@ class BaseCallback(ABC): """ def __init__(self, verbose: int = 0): - super(BaseCallback, self).__init__() + super().__init__() # The RL model self.model = None # type: Optional[base_class.BaseAlgorithm] # An alias for self.model.get_env(), the environment used for training @@ -127,14 +127,14 @@ class EventCallback(BaseCallback): """ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): - super(EventCallback, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.callback = callback # Give access to the parent if callback is not None: self.callback.parent = self def init_callback(self, model: "base_class.BaseAlgorithm") -> None: - super(EventCallback, self).init_callback(model) + super().init_callback(model) if self.callback is not None: self.callback.init_callback(self.model) @@ -169,7 +169,7 @@ class CallbackList(BaseCallback): """ def __init__(self, callbacks: List[BaseCallback]): - super(CallbackList, self).__init__() + super().__init__() assert isinstance(callbacks, list) self.callbacks = callbacks @@ -228,7 +228,7 @@ class CheckpointCallback(BaseCallback): """ def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0): - super(CheckpointCallback, self).__init__(verbose) + super().__init__(verbose) self.save_freq = save_freq self.save_path = save_path self.name_prefix = name_prefix @@ -256,7 +256,7 @@ class ConvertCallback(BaseCallback): """ 
def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0): - super(ConvertCallback, self).__init__(verbose) + super().__init__(verbose) self.callback = callback def _on_step(self) -> bool: @@ -307,7 +307,7 @@ class EvalCallback(EventCallback): verbose: int = 1, warn: bool = True, ): - super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose) + super().__init__(callback_after_eval, verbose=verbose) self.callback_on_new_best = callback_on_new_best if self.callback_on_new_best is not None: @@ -480,7 +480,7 @@ class StopTrainingOnRewardThreshold(BaseCallback): """ def __init__(self, reward_threshold: float, verbose: int = 0): - super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.reward_threshold = reward_threshold def _on_step(self) -> bool: @@ -505,7 +505,7 @@ class EveryNTimesteps(EventCallback): """ def __init__(self, n_steps: int, callback: BaseCallback): - super(EveryNTimesteps, self).__init__(callback) + super().__init__(callback) self.n_steps = n_steps self.last_time_trigger = 0 @@ -528,7 +528,7 @@ class StopTrainingOnMaxEpisodes(BaseCallback): """ def __init__(self, max_episodes: int, verbose: int = 0): - super(StopTrainingOnMaxEpisodes, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.max_episodes = max_episodes self._total_max_episodes = max_episodes self.n_episodes = 0 @@ -573,7 +573,7 @@ class StopTrainingOnNoModelImprovement(BaseCallback): """ def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): - super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.max_no_improvement_evals = max_no_improvement_evals self.min_evals = min_evals self.last_best_mean_reward = -np.inf diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 1c0e54a..3d1ff5a 100644 --- 
a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -16,7 +16,7 @@ class Distribution(ABC): """Abstract base class for distributions.""" def __init__(self): - super(Distribution, self).__init__() + super().__init__() self.distribution = None @abstractmethod @@ -120,7 +120,7 @@ class DiagGaussianDistribution(Distribution): """ def __init__(self, action_dim: int): - super(DiagGaussianDistribution, self).__init__() + super().__init__() self.action_dim = action_dim self.mean_actions = None self.log_std = None @@ -201,13 +201,13 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): """ def __init__(self, action_dim: int, epsilon: float = 1e-6): - super(SquashedDiagGaussianDistribution, self).__init__(action_dim) + super().__init__(action_dim) # Avoid NaN (prevents division by zero or log of zero) self.epsilon = epsilon self.gaussian_actions = None def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution": - super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) + super().proba_distribution(mean_actions, log_std) return self def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: @@ -219,7 +219,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): gaussian_actions = TanhBijector.inverse(actions) # Log likelihood for a Gaussian distribution - log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions) + log_prob = super().log_prob(gaussian_actions) # Squash correction (from original SAC implementation) # this comes from the fact that tanh is bijective and differentiable log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1) @@ -254,7 +254,7 @@ class CategoricalDistribution(Distribution): """ def __init__(self, action_dim: int): - super(CategoricalDistribution, self).__init__() + super().__init__() self.action_dim = action_dim def 
proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -305,7 +305,7 @@ class MultiCategoricalDistribution(Distribution): """ def __init__(self, action_dims: List[int]): - super(MultiCategoricalDistribution, self).__init__() + super().__init__() self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -360,7 +360,7 @@ class BernoulliDistribution(Distribution): """ def __init__(self, action_dims: int): - super(BernoulliDistribution, self).__init__() + super().__init__() self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -433,7 +433,7 @@ class StateDependentNoiseDistribution(Distribution): learn_features: bool = False, epsilon: float = 1e-6, ): - super(StateDependentNoiseDistribution, self).__init__() + super().__init__() self.action_dim = action_dim self.latent_sde_dim = None self.mean_actions = None @@ -597,7 +597,7 @@ class StateDependentNoiseDistribution(Distribution): return actions, log_prob -class TanhBijector(object): +class TanhBijector: """ Bijective transformation of a probability distribution using a squashing function (tanh) @@ -607,7 +607,7 @@ class TanhBijector(object): """ def __init__(self, epsilon: float = 1e-6): - super(TanhBijector, self).__init__() + super().__init__() self.epsilon = epsilon @staticmethod diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index c5d713a..a881b32 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -36,7 +36,7 @@ class BitFlippingEnv(GoalEnv): image_obs_space: bool = False, channel_first: bool = True, ): - super(BitFlippingEnv, self).__init__() + super().__init__() # Shape of the observation when using image space self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) # The achieved goal is determined by the current state @@ -115,7 +115,7 @@ class BitFlippingEnv(GoalEnv): if 
self.discrete_obs_space: # The internal state is the binary representation of the # observed one - return int(sum([state[i] * 2**i for i in range(len(state))])) + return int(sum(state[i] * 2**i for i in range(len(state)))) if self.image_obs_space: size = np.prod(self.image_shape) @@ -135,7 +135,7 @@ class BitFlippingEnv(GoalEnv): if isinstance(state, int): state = np.array(state).reshape(batch_size, -1) # Convert to binary representation - state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int) + state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int) elif self.image_obs_space: state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 else: diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 177a641..2e5f13f 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -42,7 +42,7 @@ class SimpleMultiObsEnv(gym.Env): discrete_actions: bool = True, channel_last: bool = True, ): - super(SimpleMultiObsEnv, self).__init__() + super().__init__() self.vector_size = 5 if channel_last: diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 6493a3e..7cc3d0a 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -24,7 +24,7 @@ ERROR = 40 DISABLED = 50 -class Video(object): +class Video: """ Video data class storing the video frames and the frame per seconds @@ -37,7 +37,7 @@ class Video(object): self.fps = fps -class Figure(object): +class Figure: """ Figure data class storing a matplotlib figure and whether to close the figure after logging it @@ -50,7 +50,7 @@ class Figure(object): self.close = close -class Image(object): +class Image: """ Image data class storing an image and data format @@ -80,13 +80,13 @@ class FormatUnsupportedError(NotImplementedError): format_str = f"formats {', '.join(unsupported_formats)} are" else: 
format_str = f"format {unsupported_formats[0]} is" - super(FormatUnsupportedError, self).__init__( + super().__init__( f"The {format_str} not supported for the {value_description} value logged.\n" f"You can exclude formats via the `exclude` parameter of the logger's `record` function." ) -class KVWriter(object): +class KVWriter: """ Key Value writer """ @@ -108,7 +108,7 @@ class KVWriter(object): raise NotImplementedError -class SeqWriter(object): +class SeqWriter: """ sequence writer """ @@ -427,7 +427,7 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWr # ================================================================ -class Logger(object): +class Logger: """ The logger class. @@ -623,7 +623,7 @@ def read_json(filename: str) -> pandas.DataFrame: :return: the data in the json """ data = [] - with open(filename, "rt") as file_handler: + with open(filename) as file_handler: for line in file_handler: data.append(json.loads(line)) return pandas.DataFrame(data) diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 04cda22..a482b72 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -36,7 +36,7 @@ class Monitor(gym.Wrapper): reset_keywords: Tuple[str, ...] = (), info_keywords: Tuple[str, ...] 
= (), ): - super(Monitor, self).__init__(env=env) + super().__init__(env=env) self.t_start = time.time() if filename is not None: self.results_writer = ResultsWriter( @@ -110,7 +110,7 @@ class Monitor(gym.Wrapper): """ Closes the environment """ - super(Monitor, self).close() + super().close() if self.results_writer is not None: self.results_writer.close() @@ -224,7 +224,7 @@ def load_results(path: str) -> pandas.DataFrame: raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}") data_frames, headers = [], [] for file_name in monitor_files: - with open(file_name, "rt") as file_handler: + with open(file_name) as file_handler: first_line = file_handler.readline() assert first_line[0] == "#" header = json.loads(first_line[1:]) diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index b1db6f4..119ed36 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -11,7 +11,7 @@ class ActionNoise(ABC): """ def __init__(self): - super(ActionNoise, self).__init__() + super().__init__() def reset(self) -> None: """ @@ -35,7 +35,7 @@ class NormalActionNoise(ActionNoise): def __init__(self, mean: np.ndarray, sigma: np.ndarray): self._mu = mean self._sigma = sigma - super(NormalActionNoise, self).__init__() + super().__init__() def __call__(self) -> np.ndarray: return np.random.normal(self._mu, self._sigma) @@ -72,7 +72,7 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise): self.initial_noise = initial_noise self.noise_prev = np.zeros_like(self._mu) self.reset() - super(OrnsteinUhlenbeckActionNoise, self).__init__() + super().__init__() def __call__(self) -> np.ndarray: noise = ( diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 5905dee..ca57166 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -102,7 +102,7 @@ class 
OffPolicyAlgorithm(BaseAlgorithm): supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - super(OffPolicyAlgorithm, self).__init__( + super().__init__( policy=policy, env=env, learning_rate=learning_rate, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 281758c..763c108 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -72,7 +72,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - super(OnPolicyAlgorithm, self).__init__( + super().__init__( policy=policy, env=env, learning_rate=learning_rate, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index c322dc6..51a3d37 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -67,7 +67,7 @@ class BaseModel(nn.Module, ABC): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(BaseModel, self).__init__() + super().__init__() if optimizer_kwargs is None: optimizer_kwargs = {} @@ -267,7 +267,7 @@ class BasePolicy(BaseModel): """ def __init__(self, *args, squash_output: bool = False, **kwargs): - super(BasePolicy, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._squash_output = squash_output @staticmethod @@ -437,7 +437,7 @@ class ActorCriticPolicy(BasePolicy): if optimizer_class == th.optim.Adam: optimizer_kwargs["eps"] = 1e-5 - super(ActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -724,7 +724,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(ActorCriticCnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, 
@@ -799,7 +799,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MultiInputActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index fb3ae8b..b48f922 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -3,7 +3,7 @@ from typing import Tuple, Union import numpy as np -class RunningMeanStd(object): +class RunningMeanStd: def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): """ Calulates the running mean and std of a data stream diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py index ba70a5f..377b7f6 100644 --- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py @@ -54,21 +54,21 @@ class RMSpropTFLike(Optimizer): centered: bool = False, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= momentum: - raise ValueError("Invalid momentum value: {}".format(momentum)) + raise ValueError(f"Invalid momentum value: {momentum}") if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= alpha: - raise ValueError("Invalid alpha value: {}".format(alpha)) + raise ValueError(f"Invalid alpha value: {alpha}") defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) - super(RMSpropTFLike, 
self).__init__(params, defaults) + super().__init__(params, defaults) def __setstate__(self, state: Dict[str, Any]) -> None: - super(RMSpropTFLike, self).__setstate__(state) + super().__setstate__(state) for group in self.param_groups: group.setdefault("momentum", 0) group.setdefault("centered", False) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 589d12e..8fd2237 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -19,7 +19,7 @@ class BaseFeaturesExtractor(nn.Module): """ def __init__(self, observation_space: gym.Space, features_dim: int = 0): - super(BaseFeaturesExtractor, self).__init__() + super().__init__() assert features_dim > 0 self._observation_space = observation_space self._features_dim = features_dim @@ -41,7 +41,7 @@ class FlattenExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) + super().__init__(observation_space, get_flattened_obs_dim(observation_space)) self.flatten = nn.Flatten() def forward(self, observations: th.Tensor) -> th.Tensor: @@ -61,7 +61,7 @@ class NatureCNN(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): - super(NatureCNN, self).__init__(observation_space, features_dim) + super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper assert is_image_space(observation_space, check_channels=False), ( @@ -169,7 +169,7 @@ class MlpExtractor(nn.Module): activation_fn: Type[nn.Module], device: Union[th.device, str] = "auto", ): - super(MlpExtractor, self).__init__() + super().__init__() device = get_device(device) shared_net, policy_net, value_net = [], [], [] policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network 
@@ -250,7 +250,7 @@ class CombinedExtractor(BaseFeaturesExtractor): def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256): # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! - super(CombinedExtractor, self).__init__(observation_space, features_dim=1) + super().__init__(observation_space, features_dim=1) extractors = {} diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index affd775..733b728 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -7,7 +7,7 @@ from gym import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first -class StackedObservations(object): +class StackedObservations: """ Frame stacking wrapper for data. diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 04f5d0c..f723c71 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -217,6 +217,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space elif isinstance(space, gym.spaces.Tuple): assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len))) + return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) else: return np.stack(obs) diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 859f1ec..ca590cb 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -37,7 +37,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> return obs_dict elif 
isinstance(obs_space, gym.spaces.Tuple): assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" - return tuple((obs_dict[i] for i in range(len(obs_space.spaces)))) + return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) else: assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" return obs_dict[None] diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index e6f728b..b6b0ad8 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -26,7 +26,7 @@ class VecTransposeImage(VecEnvWrapper): self.skip = skip # Do nothing if skip: - super(VecTransposeImage, self).__init__(venv) + super().__init__(venv) return if isinstance(venv.observation_space, spaces.dict.Dict): @@ -39,7 +39,7 @@ class VecTransposeImage(VecEnvWrapper): observation_space.spaces[key] = self.transpose_space(space, key) else: observation_space = self.transpose_space(venv.observation_space) - super(VecTransposeImage, self).__init__(venv, observation_space=observation_space) + super().__init__(venv, observation_space=observation_space) @staticmethod def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 14293ca..53d3fb6 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -78,7 +78,7 @@ class DDPG(TD3): _init_setup_model: bool = True, ): - super(DDPG, self).__init__( + super().__init__( policy=policy, env=env, learning_rate=learning_rate, @@ -127,7 +127,7 @@ class DDPG(TD3): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(DDPG, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/dqn/dqn.py 
b/stable_baselines3/dqn/dqn.py index ed6073b..fe8f398 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -95,7 +95,7 @@ class DQN(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(DQN, self).__init__( + super().__init__( policy, env, learning_rate, @@ -138,7 +138,7 @@ class DQN(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(DQN, self)._setup_model() + super()._setup_model() self._create_aliases() self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, @@ -261,7 +261,7 @@ class DQN(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(DQN, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -274,7 +274,7 @@ class DQN(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"] + return super()._excluded_save_params() + ["q_net", "q_net_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "policy.optimizer"] diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index ea00b5c..ed3497c 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -37,7 +37,7 @@ class QNetwork(BasePolicy): activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, ): - super(QNetwork, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -118,7 +118,7 @@ class DQNPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(DQNPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -239,7 +239,7 @@ class CnnPolicy(DQNPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: 
Optional[Dict[str, Any]] = None, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -284,7 +284,7 @@ class MultiInputPolicy(DQNPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index f61a786..c461d19 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -82,7 +82,7 @@ class HerReplayBuffer(DictReplayBuffer): handle_timeout_termination: bool = True, ): - super(HerReplayBuffer, self).__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs) + super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs) # convert goal_selection_strategy into GoalSelectionStrategy if string if isinstance(goal_selection_strategy, str): diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 346cc02..5b8d9e2 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -99,7 +99,7 @@ class PPO(OnPolicyAlgorithm): _init_setup_model: bool = True, ): - super(PPO, self).__init__( + super().__init__( policy, env, learning_rate=learning_rate, @@ -162,7 +162,7 @@ class PPO(OnPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(PPO, self)._setup_model() + super()._setup_model() # Initialize schedules for policy/value clipping self.clip_range = get_schedule_fn(self.clip_range) @@ -307,7 +307,7 @@ class PPO(OnPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> "PPO": - return super(PPO, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py 
index cb6a61c..6fcbea1 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -65,7 +65,7 @@ class Actor(BasePolicy): clip_mean: float = 2.0, normalize_images: bool = True, ): - super(Actor, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -237,7 +237,7 @@ class SACPolicy(BasePolicy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(SACPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -424,7 +424,7 @@ class CnnPolicy(SACPolicy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -495,7 +495,7 @@ class MultiInputPolicy(SACPolicy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 3703b73..07f88d9 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -110,7 +110,7 @@ class SAC(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(SAC, self).__init__( + super().__init__( policy, env, learning_rate, @@ -150,7 +150,7 @@ class SAC(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(SAC, self)._setup_model() + super()._setup_model() self._create_aliases() # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": @@ -248,7 +248,7 @@ class SAC(OffPolicyAlgorithm): current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) + critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) 
critic_losses.append(critic_loss.item()) # Optimize the critic @@ -295,7 +295,7 @@ class SAC(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(SAC, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -308,7 +308,7 @@ class SAC(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index ce91a0f..f3ed530 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -42,7 +42,7 @@ class Actor(BasePolicy): activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, ): - super(Actor, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -121,7 +121,7 @@ class TD3Policy(BasePolicy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(TD3Policy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -283,7 +283,7 @@ class CnnPolicy(TD3Policy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -337,7 +337,7 @@ class MultiInputPolicy(TD3Policy): n_critics: int = 2, share_features_extractor: bool = True, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index d31720b..34a783d 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -95,7 +95,7 @@ class 
TD3(OffPolicyAlgorithm): _init_setup_model: bool = True, ): - super(TD3, self).__init__( + super().__init__( policy, env, learning_rate, @@ -129,7 +129,7 @@ class TD3(OffPolicyAlgorithm): self._setup_model() def _setup_model(self) -> None: - super(TD3, self)._setup_model() + super()._setup_model() self._create_aliases() def _create_aliases(self) -> None: @@ -168,7 +168,7 @@ class TD3(OffPolicyAlgorithm): current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) + critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) critic_losses.append(critic_loss.item()) # Optimize the critics @@ -208,7 +208,7 @@ class TD3(OffPolicyAlgorithm): reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(TD3, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -221,7 +221,7 @@ class TD3(OffPolicyAlgorithm): ) def _excluded_save_params(self) -> List[str]: - return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index d6a9f8c..bccb8c6 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a4 +1.5.1a5 diff --git a/tests/test_gae.py b/tests/test_gae.py index 54e03b8..8e461ed 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -10,7 +10,7 @@ from stable_baselines3.common.policies import ActorCriticPolicy class CustomEnv(gym.Env): def __init__(self, max_steps=8): - super(CustomEnv, self).__init__() + super().__init__() self.observation_space = 
gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.max_steps = max_steps @@ -54,7 +54,7 @@ class InfiniteHorizonEnv(gym.Env): class CheckGAECallback(BaseCallback): def __init__(self): - super(CheckGAECallback, self).__init__(verbose=0) + super().__init__(verbose=0) def _on_rollout_end(self): buffer = self.model.rollout_buffer @@ -99,7 +99,7 @@ class CustomPolicy(ActorCriticPolicy): """Custom Policy with a constant value function""" def __init__(self, *args, **kwargs): - super(CustomPolicy, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.constant_value = 0.0 def forward(self, obs, deterministic=False): diff --git a/tests/test_her.py b/tests/test_her.py index 0f6d75f..888d36a 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -156,7 +156,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): params = deepcopy(model.policy.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values model.policy.load_state_dict(random_params) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index d3d041b..4c1d3cf 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -14,7 +14,7 @@ def test_monitor(tmp_path): """ env = gym.make("CartPole-v1") env.seed(0) - monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env = Monitor(env, monitor_file) monitor_env.reset() total_steps = 1000 @@ -37,7 +37,7 @@ def test_monitor(tmp_path): assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards) _ = monitor_env.get_episode_times() - with 
open(monitor_file, "rt") as file_handler: + with open(monitor_file) as file_handler: first_line = file_handler.readline() assert first_line.startswith("#") metadata = json.loads(first_line[1:]) @@ -56,7 +56,7 @@ def test_monitor_load_results(tmp_path): tmp_path = str(tmp_path) env1 = gym.make("CartPole-v1") env1.seed(0) - monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env1 = Monitor(env1, monitor_file1) monitor_files = get_monitor_files(tmp_path) @@ -76,7 +76,7 @@ def test_monitor_load_results(tmp_path): env2 = gym.make("CartPole-v1") env2.seed(0) - monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env2 = Monitor(env2, monitor_file2) monitor_files = get_monitor_files(tmp_path) assert len(monitor_files) == 2 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 452e6fb..2fdebbe 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -64,7 +64,7 @@ def test_save_load(tmp_path, model_class): model.set_parameters(invalid_object_params, exact_match=False) # Test that exact_match catches when something was missed. 
- missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1]) + missing_object_params = {k: v for k, v in list(original_params.items())[:-1]} with pytest.raises(ValueError): model.set_parameters(missing_object_params, exact_match=True) @@ -446,7 +446,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde): params = deepcopy(policy.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values policy.load_state_dict(random_params) @@ -537,7 +537,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): params = deepcopy(q_net.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values q_net.load_state_dict(random_params) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 54994b2..b754042 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -9,7 +9,7 @@ from stable_baselines3.common.evaluation import evaluate_policy class DummyMultiDiscreteSpace(gym.Env): def __init__(self, nvec): - super(DummyMultiDiscreteSpace, self).__init__() + super().__init__() self.observation_space = gym.spaces.MultiDiscrete(nvec) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) @@ -22,7 +22,7 @@ class DummyMultiDiscreteSpace(gym.Env): class DummyMultiBinary(gym.Env): def __init__(self, n): - super(DummyMultiBinary, self).__init__() + super().__init__() self.observation_space = gym.spaces.MultiBinary(n) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) diff --git 
a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index 1ea2efe..4f023e9 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -28,7 +28,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenBatchNormDropoutExtractor, self).__init__( + super().__init__( observation_space, get_flattened_obs_dim(observation_space), ) diff --git a/tests/test_utils.py b/tests/test_utils.py index b07bbe9..67f2ad1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -180,7 +180,7 @@ class AlwaysDoneWrapper(gym.Wrapper): # Pretends that environment only has single step for each # episode. def __init__(self, env): - super(AlwaysDoneWrapper, self).__init__(env) + super().__init__(env) self.last_obs = None self.needs_reset = True diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 265da2e..9623557 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -12,7 +12,7 @@ class NanAndInfEnv(gym.Env): metadata = {"render.modes": ["human"]} def __init__(self): - super(NanAndInfEnv, self).__init__() + super().__init__() self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 974202b..5ccc33e 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -36,7 +36,7 @@ def test_vec_monitor(tmp_path): monitor_env.close() - with open(monitor_file, "rt") as file_handler: + with open(monitor_file) as file_handler: first_line = file_handler.readline() assert first_line.startswith("#") metadata = json.loads(first_line[1:]) @@ -66,7 +66,7 @@ def test_vec_monitor_info_keywords(tmp_path): monitor_env.close() - with open(monitor_file, "rt") as f: + with open(monitor_file) as f: reader = csv.reader(f) for i, line in 
enumerate(reader): if i == 0 or i == 1: diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 07ad77f..86a0d84 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -47,7 +47,7 @@ class DummyDictEnv(gym.GoalEnv): """ def __init__(self): - super(DummyDictEnv, self).__init__() + super().__init__() self.observation_space = spaces.Dict( { "observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32), From c5f0aa5de0a1a8a8b226665cfe45ccb09df353bc Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 1 May 2022 16:26:34 +0200 Subject: [PATCH 09/33] Update doc: PPO blog post and remark on timeouts (#896) --- docs/guide/rl_tips.rst | 19 ++++++++++++++++--- docs/misc/changelog.rst | 2 ++ docs/modules/ppo.rst | 1 + 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index 031f947..2f093f6 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -183,6 +183,16 @@ Some basic advice: - start with shaped reward (i.e. informative reward) and simplified version of your problem - debug with random actions to check that your environment works and follows the gym interface: +Two important things to keep in mind when creating a custom environment are to avoid breaking the Markov assumption +and properly handle termination due to a timeout (maximum number of steps in an episode). +For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give a history of observations +as input. + +Termination due to timeout (max number of steps per episode) needs to be handled separately. You should fill the key in the info dict: ``info["TimeLimit.truncated"] = True``. +If you are using the gym ``TimeLimit`` wrapper, this will be done automatically. +You can read `Time Limit in RL `_ or take a look at the `RL Tips and Tricks video `_ +for more details.
+ We provide a helper to check that your environment runs without error: @@ -241,12 +251,15 @@ We *recommend following those steps to have a working RL algorithm*: 1. Read the original paper several times 2. Read existing implementations (if available) 3. Try to have some "sign of life" on toy problems -4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo) - You usually need to run hyperparameter optimization for that step. +4. Validate the implementation by making it run on harder and harder envs (you can compare results against the RL zoo). + You usually need to run hyperparameter optimization for that step. -You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf `issue #75 `_) +You need to be particularly careful on the shape of the different objects you are manipulating (a broadcast mistake will fail silently cf. `issue #75 `_) and when to stop the gradient propagation. +Don't forget to handle termination due to timeout separately (see remark in the custom environment section above), +you can also take a look at `Issue #284 `_ and `Issue #633 `_. + A personal pick (by @araffin) for environments with gradual difficulty in RL with continuous actions: 1. 
Pendulum (easy to solve) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 409b672..42a1d5a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -37,6 +37,8 @@ Documentation: ^^^^^^^^^^^^^^ - Added link to gym doc and gym env checker - Fix typo in PPO doc (@bcollazo) +- Added link to PPO ICLR blog post +- Added remark about breaking Markov assumption and timeout handling Release 1.5.0 (2022-03-25) diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index 3aab653..d32986c 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -25,6 +25,7 @@ Notes - Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8 - OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/ - Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html +- 37 implementation details blog: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/ Can I use? From db57cb67e3d404a32d11a8e68822f19541b69621 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 3 May 2022 16:27:27 +0200 Subject: [PATCH 10/33] Fix gitlab coverage report (#901) --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 45ca8f5..20953d2 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,6 +9,7 @@ pytest: - python --version # MKL_THREADING_LAYER=GNU to avoid MKL_THREADING_LAYER=INTEL incompatibility error - MKL_THREADING_LAYER=GNU make pytest + coverage: '/^TOTAL.+?(\d+\%)$/' doc-build: script: From e98ae129de53a5699af2b1119d7151c5434e8df1 Mon Sep 17 00:00:00 2001 From: Marsel Khisamutdinov Date: Tue, 3 May 2022 19:27:48 +0500 Subject: [PATCH 11/33] Fix a grammatical mistake (#899) Co-authored-by: Antonin RAFFIN --- docs/guide/algos.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 474a047..5d1bf08 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -26,8 
+26,8 @@ Maskable PPO [#f1]_ ❌ ✔️ ✔️ ✔ .. [#f1] Implemented in `SB3 Contrib `_ .. note:: - ``Tuple`` observation spaces are not supported by any environment - however single-level ``Dict`` spaces are (cf. :ref:`Examples `). + ``Tuple`` observation spaces are not supported by any environment, + however, single-level ``Dict`` spaces are (cf. :ref:`Examples `). Actions ``gym.spaces``: From c2518dc160ef83e2d71cc81a35e0a894c8ab9a5e Mon Sep 17 00:00:00 2001 From: Thomas Rudolf <62146721+git-thor@users.noreply.github.com> Date: Sun, 8 May 2022 15:28:31 +0200 Subject: [PATCH 12/33] Add doc to use mlflow logger (#889) * ADD feature for mlflow logger via MLflowOutputFormat. * Move MLFlow integration to doc Co-authored-by: Antonin Raffin --- docs/guide/integrations.rst | 53 ++++++++++++++++++++++++++++++ docs/misc/changelog.rst | 3 +- stable_baselines3/common/logger.py | 35 +++++++++++--------- 3 files changed, 74 insertions(+), 17 deletions(-) diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 9007ade..7f21bd3 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -137,3 +137,56 @@ Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a filename="ppo-CartPole-v1", commit_message="Added Cartpole-v1 model trained with PPO", ) + +MLFLow +====== + +If you want to use `MLFLow `_ to track your SB3 experiments, +you can adapt the following code which defines a custom logger output: + +.. code-block:: python + + import sys + from typing import Any, Dict, Tuple, Union + + import mlflow + import numpy as np + + from stable_baselines3 import SAC + from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger + + + class MLflowOutputFormat(KVWriter): + """ + Dumps key/value pairs into MLflow's numeric format. 
+ """ + + def write( + self, + key_values: Dict[str, Any], + key_excluded: Dict[str, Union[str, Tuple[str, ...]]], + step: int = 0, + ) -> None: + + for (key, value), (_, excluded) in zip( + sorted(key_values.items()), sorted(key_excluded.items()) + ): + + if excluded is not None and "mlflow" in excluded: + continue + + if isinstance(value, np.ScalarType): + if not isinstance(value, str): + mlflow.log_metric(key, value, step) + + + loggers = Logger( + folder=None, + output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()], + ) + + with mlflow.start_run(): + model = SAC("MlpPolicy", "Pendulum-v1", verbose=2) + # Set custom logger + model.set_logger(loggers) + model.learn(total_timesteps=10000, log_interval=1) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 42a1d5a..cb3a17f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -39,6 +39,7 @@ Documentation: - Fix typo in PPO doc (@bcollazo) - Added link to PPO ICLR blog post - Added remark about breaking Markov assumption and timeout handling +- Added doc about MLFlow integration via custom logger (@git-thor) Release 1.5.0 (2022-03-25) @@ -968,4 +969,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies @bcollazo +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 7cc3d0a..1295e5b 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -17,6 +17,7 @@ try: except ImportError: SummaryWriter = None + DEBUG = 10 INFO = 20 WARN = 30 @@ -246,12 +247,13 @@ def filter_excluded_keys( class JSONOutputFormat(KVWriter): - def __init__(self, 
filename: str): - """ - log to a file, in the JSON format + """ + Log to a file, in the JSON format - :param filename: the file to write the log to - """ + :param filename: the file to write the log to + """ + + def __init__(self, filename: str): self.file = open(filename, "wt") def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: @@ -287,13 +289,13 @@ class JSONOutputFormat(KVWriter): class CSVOutputFormat(KVWriter): + """ + Log to a file, in a CSV format + + :param filename: the file to write the log to + """ + def __init__(self, filename: str): - """ - log to a file, in a CSV format - - :param filename: the file to write the log to - """ - self.file = open(filename, "w+t") self.keys = [] self.separator = "," @@ -351,12 +353,13 @@ class CSVOutputFormat(KVWriter): class TensorBoardOutputFormat(KVWriter): - def __init__(self, folder: str): - """ - Dumps key/value pairs into TensorBoard's numeric format. + """ + Dumps key/value pairs into TensorBoard's numeric format. 
- :param folder: the folder to write the log to - """ + :param folder: the folder to write the log to + """ + + def __init__(self, folder: str): assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so" self.writer = SummaryWriter(log_dir=folder) From 0fadc94df3d658dbfac1b727d5b7febf336e4a36 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 8 May 2022 20:54:34 +0200 Subject: [PATCH 13/33] Fix synchronization bug with EvalCallback (#907) --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/vec_env/__init__.py | 4 +++- stable_baselines3/version.txt | 2 +- tests/test_vec_normalize.py | 17 ++++++++++++++--- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cb3a17f..d8b6d6d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a5 (WIP) +Release 1.5.1a6 (WIP) --------------------------- Breaking Changes: @@ -25,6 +25,7 @@ Bug Fixes: - Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) - Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) - Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. 
None value was unchecked (@ScheiklP) +- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 37ebc36..3880fbd 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: env_tmp, eval_env_tmp = env, eval_env while isinstance(env_tmp, VecEnvWrapper): if isinstance(env_tmp, VecNormalize): - eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) + # Only synchronize if observation normalization exists + if hasattr(env_tmp, "obs_rms"): + eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) env_tmp = env_tmp.venv eval_env_tmp = eval_env_tmp.venv diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index bccb8c6..1e5deca 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a5 +1.5.1a6 diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 86a0d84..a363e40 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -388,11 +388,11 @@ def test_offpolicy_normalization(model_class, online_sampling): @pytest.mark.parametrize("make_env", [make_env, make_dict_env]) def test_sync_vec_normalize(make_env): - env = DummyVecEnv([make_env]) + original_env = DummyVecEnv([make_env]) - assert unwrap_vec_normalize(env) is None + assert unwrap_vec_normalize(original_env) is None - env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) + env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) assert isinstance(unwrap_vec_normalize(env), VecNormalize) @@ -433,6 +433,17 @@ def 
test_sync_vec_normalize(make_env): assert allclose(obs, eval_env.normalize_obs(original_obs)) assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards)) + # Check synchronization when only reward is normalized + env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0) + eval_env = DummyVecEnv([make_env]) + eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False) + env.reset() + env.step([env.action_space.sample()]) + assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) + sync_envs_normalization(env, eval_env) + assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) + assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var) + def test_discrete_obs(): with pytest.raises(ValueError, match=".*only supports.*"): From 2fcf8f91c1fa07bbdd920a27b3874e67997ad754 Mon Sep 17 00:00:00 2001 From: TibiGG <28602860+TibiGG@users.noreply.github.com> Date: Mon, 9 May 2022 12:36:15 +0100 Subject: [PATCH 14/33] Removed redundant double-check of nested Dict (#908) * Removed redundant double-check of nested Dict observation space from BaseAlgorithm * Update changelog Co-authored-by: tibigg --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/base_class.py | 5 ----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d8b6d6d..4bcecc4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -33,6 +33,7 @@ Deprecations: Others: ^^^^^^^ - Upgraded to Python 3.7+ syntax using ``pyupgrade`` +- Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG) Documentation: ^^^^^^^^^^^^^^ @@ -970,4 +971,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J 
@gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 14570be..bdba909 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -209,11 +209,6 @@ class BaseAlgorithm(ABC): # Make sure that dict-spaces are not nested (not supported) check_for_nested_spaces(env.observation_space) - if isinstance(env.observation_space, gym.spaces.Dict): - for space in env.observation_space.spaces.values(): - if isinstance(space, gym.spaces.Dict): - raise ValueError("Nested observation spaces are not supported (Dict spaces inside Dict space).") - if not is_vecenv_wrapped(env, VecTransposeImage): wrap_with_vectranspose = False if isinstance(env.observation_space, gym.spaces.Dict): From 49813d8c68c3642164081984227823ffc33e66d1 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 25 May 2022 10:24:21 -0400 Subject: [PATCH 15/33] Update doc and add check for unbounded action space (#918) --- docs/guide/integrations.rst | 3 +++ docs/misc/changelog.rst | 2 ++ stable_baselines3/common/base_class.py | 5 +++++ stable_baselines3/common/env_checker.py | 5 +++++ tests/test_envs.py | 12 ++++++++++-- 5 files changed, 25 insertions(+), 2 deletions(-) diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 7f21bd3..98cf84a 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -48,11 +48,14 @@ Hugging Face 🤗 The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾. You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3 +Most of them are available via the RL Zoo. 
Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3 We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T +For up-to-date instructions (for instance for using ``package_to_hub()``), please take a look at the Huggingface SB3 package README: https://github.com/huggingface/huggingface_sb3 + Installation ------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4bcecc4..c473fc0 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -26,6 +26,7 @@ Bug Fixes: - Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) - Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP) - Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled +- Added a check for unbounded actions Deprecations: ^^^^^^^^^^^^^ @@ -42,6 +43,7 @@ Documentation: - Added link to PPO ICLR blog post - Added remark about breaking Markov assumption and timeout handling - Added doc about MLFlow integration via custom logger (@git-thor) +- Updated Huggingface integration doc Release 1.5.0 (2022-03-25) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index bdba909..e16814c 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -185,6 +185,11 @@ class BaseAlgorithm(ABC): if self.use_sde and not isinstance(self.action_space, gym.spaces.Box): raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.") + if isinstance(self.action_space, gym.spaces.Box): + assert np.all( + np.isfinite(np.array([self.action_space.low, self.action_space.high])) + ), "Continuous action space must have a finite lower and upper bound"
+ @staticmethod def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv: """ " diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index c4e5669..ed07e7e 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -274,6 +274,11 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - "cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" ) + if isinstance(action_space, spaces.Box): + assert np.all( + np.isfinite(np.array([action_space.low, action_space.high])) + ), "Continuous action space must have a finite lower and upper bound" + if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32): warnings.warn( f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors." diff --git a/tests/test_envs.py b/tests/test_envs.py index 671e2a5..1cec17a 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -141,6 +141,8 @@ def test_non_default_spaces(new_obs_space): spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32), # Same boundaries spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32), + # Unbounded action space + spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), ], @@ -156,8 +158,14 @@ def test_non_default_action_spaces(new_action_space): # Change the action space env.action_space = new_action_space - with pytest.warns(UserWarning): - check_env(env) + # Unbounded action space throws an error, + # the rest only warning + if not np.all(np.isfinite(env.action_space.low)): + with pytest.raises(AssertionError), pytest.warns(UserWarning): + check_env(env) + else: + with pytest.warns(UserWarning): + check_env(env) def check_reset_assert_error(env, new_reset_return): From 
4b89fbf283c58486ff945b21451c987a83e84591 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 29 May 2022 15:09:50 -0400 Subject: [PATCH 16/33] Fix issues due to newer version of protobuf and sphinx (#924) --- docs/conf.py | 2 +- docs/misc/changelog.rst | 3 ++- setup.py | 13 ++++++++----- stable_baselines3/version.txt | 2 +- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 712908e..18898d5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -100,7 +100,7 @@ master_doc = "index" # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c473fc0..abd7c0c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a6 (WIP) +Release 1.5.1a7 (WIP) --------------------------- Breaking Changes: @@ -27,6 +27,7 @@ Bug Fixes: - Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. 
None value was unchecked (@ScheiklP) - Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled - Added a check for unbounded actions +- Fixed issues due to newer version of protobuf (tensorboard) and sphinx Deprecations: ^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index cd8f209..05745e9 100644 --- a/setup.py +++ b/setup.py @@ -43,10 +43,10 @@ import gym from stable_baselines3 import PPO -env = gym.make('CartPole-v1') +env = gym.make("CartPole-v1") -model = PPO('MlpPolicy', env, verbose=1) -model.learn(total_timesteps=10000) +model = PPO("MlpPolicy", env, verbose=1) +model.learn(total_timesteps=10_000) obs = env.reset() for i in range(1000): @@ -57,12 +57,12 @@ for i in range(1000): obs = env.reset() ``` -Or just train a model with a one liner if [the environment is registered in Gym](https://github.com/openai/gym/wiki/Environments) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): +Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): ```python from stable_baselines3 import PPO -model = PPO('MlpPolicy', 'CartPole-v1').learn(10000) +model = PPO("MlpPolicy", "CartPole-v1").learn(10_000) ``` """ # noqa:E501 @@ -121,6 +121,9 @@ setup( "pillow", # Tensorboard support "tensorboard>=2.2.0", + # Protobuf >= 4 has breaking changes + # which does not play well with tensorboard + "protobuf~=3.19.0", # Checking memory taken by replay buffer "psutil", ], diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 1e5deca..e39732b 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a6 +1.5.1a7 From d68f0a2411766beb6da58ee0e989d1a6a72869bc Mon Sep 17 00:00:00 2001 From: Antonin
RAFFIN Date: Tue, 31 May 2022 12:11:16 -0400 Subject: [PATCH 17/33] Update doc: SB3 Contrib RecurrentPPO (#927) * Update doc: contrib update * Update docs/misc/changelog.rst Co-authored-by: Anssi * Address Anssi comments Co-authored-by: Anssi --- README.md | 7 ++++--- docs/guide/algos.rst | 1 + docs/guide/sb3_contrib.rst | 6 ++++-- docs/misc/changelog.rst | 4 +++- docs/modules/ppo.rst | 12 +++++++++++- stable_baselines3/version.txt | 2 +- 6 files changed, 24 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index f727547..fd6c8db 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.h We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) -This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). +This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). 
Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) @@ -122,7 +122,7 @@ from stable_baselines3 import PPO env = gym.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) -model.learn(total_timesteps=10000) +model.learn(total_timesteps=10_000) obs = env.reset() for i in range(1000): @@ -140,7 +140,7 @@ Or just train a model with a one liner if [the environment is registered in Gym] ```python from stable_baselines3 import PPO -model = PPO('MlpPolicy', 'CartPole-v1').learn(10000) +model = PPO("MlpPolicy", "CartPole-v1").learn(10_000) ``` Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples. @@ -172,6 +172,7 @@ All the following examples can be executed online using Google colab notebooks: | HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | | PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | QR-DQN[1](#f1) | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | +| RecurrentPPO[1](#f1) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | TQC[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 5d1bf08..55aba35 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -15,6 +15,7 @@ DQN ❌ ✔️ ❌ ❌ HER ✔️ ✔️ ❌ ❌ ❌ PPO ✔️ ✔️ ✔️ ✔️ ✔️ QR-DQN [#f1]_ ❌ ️ ✔️ ❌ ❌ ✔️ +RecurrentPPO [#f1]_ ✔️ ✔️ ✔️ ✔️ ✔️ SAC ✔️ ❌ ❌ ❌ ✔️ TD3 ✔️ ❌ ❌ ❌ ✔️ TQC [#f1]_ ✔️ ❌ ❌ ❌ ✔️ diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst index 1dfa912..445832c 100644 --- a/docs/guide/sb3_contrib.rst +++ b/docs/guide/sb3_contrib.rst @@ -8,7 +8,7 @@ We implement experimental 
features in a separate contrib repository: `SB3-Contrib`_ This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still -providing the latest features, like Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or +providing the latest features, like RecurrentPPO (PPO LSTM), Truncated Quantile Critics (TQC), Augmented Random Search (ARS), Trust Region Policy Optimization (TRPO) or Quantile Regression DQN (QR-DQN). Why create this repository? @@ -38,9 +38,11 @@ See documentation for the full list of included features. - `Augmented Random Search (ARS) `_ - `Quantile Regression DQN (QR-DQN)`_ +- `PPO with invalid action masking (Maskable PPO) `_ +- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) `_ - `Truncated Quantile Critics (TQC)`_ - `Trust Region Policy Optimization (TRPO) `_ -- `PPO with invalid action masking (Maskable PPO) `_ + **Gym Wrappers**: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index abd7c0c..1c88691 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a7 (WIP) +Release 1.5.1a8 (WIP) --------------------------- Breaking Changes: @@ -18,6 +18,8 @@ New Features: SB3-Contrib ^^^^^^^^^^^ +- Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53 + Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index d32986c..a562aa2 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -15,7 +15,7 @@ For that, ppo uses clipping to avoid too large update. .. note:: PPO contains several modifications from the original algorithm not documented - by OpenAI: advantages are normalized and value function can be also clipped . + by OpenAI: advantages are normalized and value function can be also clipped. Notes @@ -31,6 +31,16 @@ Notes Can I use? ---------- +.. 
note:: + + A recurrent version of PPO is available in our contrib repo: https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html + + However we advise users to start with simple frame-stacking as a simpler, faster + and usually competitive alternative, more info in our report: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4 + See also `Procgen paper appendix Fig 11. `_. + In practice, you can stack multiple observations using ``VecFrameStack``. + + - Recurrent policies: ❌ - Multi processing: ✔️ - Gym spaces: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e39732b..511e75b 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a7 +1.5.1a8 From 7ce7b6a8c25d5862ec4850480191ac66ea02d3c8 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 18 Jun 2022 10:52:52 +0200 Subject: [PATCH 18/33] Update defaults for offpolicy algos with features extractor (#935) --- docs/misc/changelog.rst | 4 +++- stable_baselines3/sac/policies.py | 11 ++++------- stable_baselines3/td3/policies.py | 8 ++++---- stable_baselines3/version.txt | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1c88691..480b08e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a8 (WIP) +Release 1.5.1a9 (WIP) --------------------------- Breaking Changes: @@ -12,6 +12,8 @@ Breaking Changes: - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) - SB3 now requires PyTorch >= 1.11 +- Changed the default network architecture when using ``CnnPolicy`` or ``MultiInputPolicy`` with SAC or DDPG/TD3, + ``share_features_extractor`` is now set to False 
by default and the ``net_arch=[256, 256]`` (instead of ``net_arch=[]`` that was before) New Features: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 6fcbea1..255bd75 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -235,7 +235,7 @@ class SACPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, @@ -248,10 +248,7 @@ class SACPolicy(BasePolicy): ) if net_arch is None: - if features_extractor_class == NatureCNN: - net_arch = [] - else: - net_arch = [256, 256] + net_arch = [256, 256] actor_arch, critic_arch = get_actor_critic_arch(net_arch) @@ -422,7 +419,7 @@ class CnnPolicy(SACPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, @@ -493,7 +490,7 @@ class MultiInputPolicy(SACPolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index f3ed530..8781b32 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -119,7 +119,7 @@ class TD3Policy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, @@ -134,7 +134,7 @@ class 
TD3Policy(BasePolicy): # Default network architecture, from the original paper if net_arch is None: if features_extractor_class == NatureCNN: - net_arch = [] + net_arch = [256, 256] else: net_arch = [400, 300] @@ -281,7 +281,7 @@ class CnnPolicy(TD3Policy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, @@ -335,7 +335,7 @@ class MultiInputPolicy(TD3Policy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, n_critics: int = 2, - share_features_extractor: bool = True, + share_features_extractor: bool = False, ): super().__init__( observation_space, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 511e75b..125ec27 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a8 +1.5.1a9 From d64bcb401ad7d45799af1feee5c1058943be23f0 Mon Sep 17 00:00:00 2001 From: Ram Rachum Date: Tue, 21 Jun 2022 22:58:02 +0300 Subject: [PATCH 19/33] Fix exception cause in base_class.py (#940) --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/base_class.py | 4 ++-- stable_baselines3/common/callbacks.py | 4 ++-- stable_baselines3/common/env_checker.py | 4 ++-- stable_baselines3/common/noise.py | 4 ++-- stable_baselines3/common/off_policy_algorithm.py | 6 ++++-- stable_baselines3/common/save_util.py | 8 ++++---- stable_baselines3/her/her_replay_buffer.py | 4 ++-- 8 files changed, 20 insertions(+), 17 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 480b08e..52bf3e4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -32,6 +32,7 @@ Bug Fixes: - Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled - Added a check 
for unbounded actions - Fixed issues due to newer version of protobuf (tensorboard) and sphinx +- Fix exception causes all over the codebase (@cool-RR) Deprecations: ^^^^^^^^^^^^^ @@ -978,4 +979,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e16814c..36a73b8 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -628,11 +628,11 @@ class BaseAlgorithm(ABC): attr = None try: attr = recursive_getattr(self, name) - except Exception: + except Exception as e: # What errors recursive_getattr could throw? KeyError, but # possible something else too (e.g. if key is an int?). # Catch anything for now. - raise ValueError(f"Key {name} is an invalid object name.") + raise ValueError(f"Key {name} is an invalid object name.") from e if isinstance(attr, th.optim.Optimizer): # Optimizers do not support "strict" keyword... 
diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index c5f297c..e9f46fe 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -380,12 +380,12 @@ class EvalCallback(EventCallback): if self.model.get_vec_normalize_env() is not None: try: sync_envs_normalization(self.training_env, self.eval_env) - except AttributeError: + except AttributeError as e: raise AssertionError( "Training and eval env are not wrapped the same way, " "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " "and warning above." - ) + ) from e # Reset success rate buffer self._is_success_buffer = [] diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index ed07e7e..3b2c502 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -147,7 +147,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action try: _check_obs(obs[key], observation_space.spaces[key], "reset") except AssertionError as e: - raise AssertionError(f"Error while checking key={key}: " + str(e)) + raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "reset") @@ -166,7 +166,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action try: _check_obs(obs[key], observation_space.spaces[key], "step") except AssertionError as e: - raise AssertionError(f"Error while checking key={key}: " + str(e)) + raise AssertionError(f"Error while checking key={key}: " + str(e)) from e else: _check_obs(obs, observation_space, "step") diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index 119ed36..baa72e9 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -105,8 +105,8 @@ class VectorizedActionNoise(ActionNoise): try: self.n_envs = int(n_envs) assert self.n_envs > 
0 - except (TypeError, AssertionError): - raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") + except (TypeError, AssertionError) as e: + raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e self.base_noise = base_noise self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)] diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index ca57166..99a02ff 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -157,8 +157,10 @@ class OffPolicyAlgorithm(BaseAlgorithm): try: train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) - except ValueError: - raise ValueError(f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!") + except ValueError as e: + raise ValueError( + f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!" + ) from e if not isinstance(train_freq[0], int): raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}") diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index e0b104f..1569001 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -206,8 +206,8 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb mode = mode.lower() try: mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode] - except KeyError: - raise ValueError("Expected mode to be either 'w' or 'r'.") + except KeyError as e: + raise ValueError("Expected mode to be either 'w' or 'r'.") from e if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable(): e1 = "writable" if "w" == mode else "readable" raise ValueError(f"Expected a {e1} file.") @@ -441,7 +441,7 @@ def load_from_zip_file( # State dicts. 
Store into params dictionary # with same name as in .zip file (without .pth) params[os.path.splitext(file_path)[0]] = th_object - except zipfile.BadZipFile: + except zipfile.BadZipFile as e: # load_path wasn't a zip file - raise ValueError(f"Error: the file {load_path} wasn't a zip-file") + raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e return data, params, pytorch_variables diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index c461d19..3c19aac 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -28,13 +28,13 @@ def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> in if current_max_episode_length is None: raise AttributeError # if not available check if a valid value was passed as an argument - except AttributeError: + except AttributeError as e: raise ValueError( "The max episode length could not be inferred.\n" "You must specify a `max_episode_steps` when registering the environment,\n" "use a `gym.wrappers.TimeLimit` wrapper " "or pass `max_episode_length` to the model constructor" - ) + ) from e return current_max_episode_length From ef10189d80dbb2efb3b5391cba4eb3c2ab5c7aae Mon Sep 17 00:00:00 2001 From: Max Weltevrede <31962715+MWeltevrede@users.noreply.github.com> Date: Mon, 4 Jul 2022 15:08:54 +0200 Subject: [PATCH 20/33] Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination (#948) * Prohibit simultaneous use of optimize_memory_buffer and handle_timeout_termination * Modify test to avoid unsupported buffer configuration * Change from assertion to raising of ValueError * Update changelog * Update style for consistency * Use handle_timeout_termination when possible Co-authored-by: Anssi Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/buffers.py | 7 +++++++ tests/test_save_load.py | 3 +++ 3 files changed, 12 insertions(+), 1 
deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 52bf3e4..5e893d9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -33,6 +33,7 @@ Bug Fixes: - Added a check for unbounded actions - Fixed issues due to newer version of protobuf (tensorboard) and sphinx - Fix exception causes all over the codebase (@cool-RR) +- Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede) Deprecations: ^^^^^^^^^^^^^ @@ -979,4 +980,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index d7728cb..d36c5f5 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -164,6 +164,7 @@ class ReplayBuffer(BaseBuffer): at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 + Cannot be used in combination with handle_timeout_termination. :param handle_timeout_termination: Handle timeout termination (due to timelimit) separately and treat the task as infinite horizon task. 
https://github.com/DLR-RM/stable-baselines3/issues/284 @@ -188,6 +189,12 @@ class ReplayBuffer(BaseBuffer): if psutil is not None: mem_available = psutil.virtual_memory().available + # there is a bug if both optimize_memory_usage and handle_timeout_termination are true + # see https://github.com/DLR-RM/stable-baselines3/issues/934 + if optimize_memory_usage and handle_timeout_termination: + raise ValueError( + "ReplayBuffer does not support optimize_memory_usage = True and handle_timeout_termination = True simultaneously." + ) self.optimize_memory_usage = optimize_memory_usage self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 2fdebbe..d7a74c5 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -375,6 +375,9 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage): select_env(model_class), buffer_size=100, optimize_memory_usage=optimize_memory_usage, + # we cannot use optimize_memory_usage and handle_timeout_termination + # at the same time + replay_buffer_kwargs={"handle_timeout_termination": not optimize_memory_usage}, policy_kwargs=dict(net_arch=[64]), learning_starts=10, ) From c1f1c3d3d796054f68b3fa741c1a0be4ce2187b5 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 12 Jul 2022 22:50:23 +0200 Subject: [PATCH 21/33] Release v1.6.0 (#958) * Release v1.6.0 + update doc + add copy button * Update read the doc conda env * Update year * Fix bug in kl divergence check * Rephrase requirement for envpool and isaac gym --- docs/conda_env.yml | 6 +++--- docs/conf.py | 13 ++++++++++++- docs/guide/examples.rst | 10 ++++++++++ docs/guide/install.rst | 11 +++++++++++ docs/misc/changelog.rst | 7 ++++++- setup.py | 2 ++ stable_baselines3/common/buffers.py | 3 ++- stable_baselines3/common/distributions.py | 3 ++- stable_baselines3/version.txt | 2 +- tests/test_distributions.py | 4 +++- 10 files changed, 52 
insertions(+), 9 deletions(-) diff --git a/docs/conda_env.yml b/docs/conda_env.yml index a01d37b..98a5508 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -6,9 +6,9 @@ dependencies: - cpuonly=1.0=0 - pip=21.1 - python=3.7 - - pytorch=1.8.1=py3.7_cpu_0 + - pytorch=1.11=py3.7_cpu_0 - pip: - - gym>=0.17.2 + - gym==0.21 - cloudpickle - opencv-python-headless - pandas @@ -16,5 +16,5 @@ dependencies: - matplotlib - sphinx_autodoc_typehints - sphinx>=4.2 - # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 - sphinx_rtd_theme>=1.0 + - sphinx_copybutton diff --git a/docs/conf.py b/docs/conf.py index 18898d5..b44be6f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,6 +24,14 @@ try: except ImportError: enable_spell_check = False +# Try to enable copy button +try: + import sphinx_copybutton # noqa: F401 + + enable_copy_button = True +except ImportError: + enable_copy_button = False + # source code directory, relative to this file, for sphinx-autobuild sys.path.insert(0, os.path.abspath("..")) @@ -51,7 +59,7 @@ with open(version_file) as file_handler: # -- Project information ----------------------------------------------------- project = "Stable Baselines3" -copyright = "2020, Stable Baselines3" +copyright = "2022, Stable Baselines3" author = "Stable Baselines3 Contributors" # The short X.Y version @@ -83,6 +91,9 @@ extensions = [ if enable_spell_check: extensions.append("sphinxcontrib.spelling") +if enable_copy_button: + extensions.append("sphinx_copybutton") + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a5b56b2..0d7e7c0 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -729,6 +729,16 @@ to keep track of the agent progress. 
model.learn(10_000) +SB3 with EnvPool or Isaac Gym +----------------------------- + +Just like Procgen (see above), `EnvPool <https://github.com/sail-sg/envpool>`_ and `Isaac Gym <https://github.com/NVIDIA-Omniverse/IsaacGymEnvs>`_ accelerate the environment by +already providing a vectorized implementation. + +To use SB3 with those tools, you must wrap the env with tool's specific ``VecEnvWrapper`` that will pre-process the data for SB3, +you can find links to those wrappers in `issue #772 <https://github.com/DLR-RM/stable-baselines3/issues/772>`_. + + Record a Video -------------- diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 3b26927..a9bb761 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -54,6 +54,17 @@ Bleeding-edge version pip install git+https://github.com/DLR-RM/stable-baselines3 +.. note:: + + If you want to use latest gym version (0.24+), you have to use + + .. code-block:: bash + + pip install git+https://github.com/carlosluis/stable-baselines3/tree/fix_tests + + See `PR #780 <https://github.com/DLR-RM/stable-baselines3/pull/780>`_ for more information. + + Development version ------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5e893d9..62f2ddb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,9 +4,11 @@ Changelog ========== -Release 1.5.1a9 (WIP) +Release 1.6.0 (2022-07-11) --------------------------- +**Recurrent PPO (PPO LSTM), better defaults for learning from pixels with SAC/TD3** + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former @@ -34,6 +36,7 @@ Bug Fixes: - Fixed issues due to newer version of protobuf (tensorboard) and sphinx - Fix exception causes all over the codebase (@cool-RR) - Prohibit simultaneous use of optimize_memory_usage and handle_timeout_termination due to a bug (@MWeltevrede) +- Fixed a bug in ``kl_divergence`` check that would fail when using numpy arrays with MultiCategorical distribution Deprecations: ^^^^^^^^^^^^^ @@ -51,6 +54,8 @@ Documentation: - Added remark about breaking Markov assumption and timeout handling - Added doc about
MLFlow integration via custom logger (@git-thor) - Updated Huggingface integration doc +- Added copy button for code snippets +- Added doc about EnvPool and Isaac Gym support Release 1.5.0 (2022-03-25) diff --git a/setup.py b/setup.py index 05745e9..2816316 100644 --- a/setup.py +++ b/setup.py @@ -111,6 +111,8 @@ setup( "sphinxcontrib.spelling", # Type hints support "sphinx-autodoc-typehints", + # Copy button for code snippets + "sphinx_copybutton", ], "extra": [ # For render diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index d36c5f5..5ed9b4c 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -193,7 +193,8 @@ class ReplayBuffer(BaseBuffer): # see https://github.com/DLR-RM/stable-baselines3/issues/934 if optimize_memory_usage and handle_timeout_termination: raise ValueError( - "ReplayBuffer does not support optimize_memory_usage = True and handle_timeout_termination = True simultaneously." + "ReplayBuffer does not support optimize_memory_usage = True " + "and handle_timeout_termination = True simultaneously." ) self.optimize_memory_usage = optimize_memory_usage diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 3d1ff5a..7096d01 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, Union import gym +import numpy as np import torch as th from gym import spaces from torch import nn @@ -688,7 +689,7 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor # MultiCategoricalDistribution is not a PyTorch Distribution subclass # so we need to implement it ourselves! 
if isinstance(dist_pred, MultiCategoricalDistribution): - assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space" + assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space" return th.stack( [th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)], dim=1, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 125ec27..dc1e644 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a9 +1.6.0 diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 3652b18..07920db 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -163,7 +163,9 @@ def test_categorical(dist, CAT_ACTIONS): BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), - MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))), + MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution( + th.rand(1, sum([N_ACTIONS, N_ACTIONS])) + ), SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), StateDependentNoiseDistribution(N_ACTIONS).proba_distribution( th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS]) From 38706f12f34b94236f3bf45b9aaef724569ea997 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 12 Jul 2022 23:47:53 +0200 Subject: [PATCH 22/33] Use ICRL url for PPO blog post --- docs/modules/ppo.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index a562aa2..829aa33 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst 
@@ -25,7 +25,7 @@ Notes - Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8 - OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/ - Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html -- 37 implementation details blog: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/ +- 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ Can I use? From a18b91e01a3fa0b49512c3a5701c517643ce1e64 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 15 Jul 2022 22:48:27 +0200 Subject: [PATCH 23/33] Replace "nature" with "Nature" (magazine) to reduce confusion (#965) * Replace "nature" with "Nature" (magazine) to reduce confusion * Replace "nature" with "Nature" (magazine) to reduce confusion * Update changelog Co-authored-by: mel --- docs/guide/migration.rst | 2 +- docs/misc/changelog.rst | 26 ++++++++++++++++++++++++ stable_baselines3/common/torch_layers.py | 2 +- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/version.txt | 2 +- 5 files changed, 30 insertions(+), 4 deletions(-) diff --git a/docs/guide/migration.rst b/docs/guide/migration.rst index 879a5fb..ef26870 100644 --- a/docs/guide/migration.rst +++ b/docs/guide/migration.rst @@ -141,7 +141,7 @@ DQN ^^^ Only the vanilla DQN is implemented right now but extensions will follow. -Default hyperparameters are taken from the nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. +Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. 
DDPG ^^^^ diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 62f2ddb..81ab0ff 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,31 @@ Changelog ========== +Release 1.6.1a0 (WIP) +--------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +SB3-Contrib +^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ +- Fix typo in docstring "nature" -> "Nature" (@Melanol) + Release 1.6.0 (2022-07-11) --------------------------- @@ -986,3 +1011,4 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede +@Melanol diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 8fd2237..f87337c 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -50,7 +50,7 @@ class FlattenExtractor(BaseFeaturesExtractor): class NatureCNN(BaseFeaturesExtractor): """ - CNN from DQN nature paper: + CNN from DQN Nature paper: Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index fe8f398..0cd6dfb 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -20,7 +20,7 @@ class DQN(OffPolicyAlgorithm): Deep Q-Network (DQN) Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236 - Default hyperparameters are taken from the nature paper, + Default hyperparameters are taken from the Nature paper, except for the optimizer and learning rate that were taken from Stable Baselines defaults. 
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index dc1e644..035e3b6 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.6.0 +1.6.1a0 From fda3d4d748439a6755260896e4350e0383a9c6f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 18 Jul 2022 11:22:19 +0200 Subject: [PATCH 24/33] Fix returned type in predict (#964) * `arr[0]` to `arr.squeeze(0)` * `squeeze(axis=0)` to `squeeze(0)` * Type testing * Add type test for unvectorized observation * `squeeze(0)` to `squeeze(axis=0)` * Treatment of the laziness symptoms * Update changelog * Udate changelog Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/distributions.py | 4 ++-- stable_baselines3/common/policies.py | 2 +- tests/test_predict.py | 2 ++ 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 81ab0ff..b258547 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ +- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) Deprecations: ^^^^^^^^^^^^^ @@ -1011,4 +1012,4 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@Melanol +@Melanol @qgallouedec diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 7096d01..5247751 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -578,10 +578,10 @@ class 
StateDependentNoiseDistribution(Distribution): return th.mm(latent_sde, self.exploration_mat) # Use batch matrix multiplication for efficient computation # (batch_size, n_features) -> (batch_size, 1, n_features) - latent_sde = latent_sde.unsqueeze(1) + latent_sde = latent_sde.unsqueeze(dim=1) # (batch_size, 1, n_actions) noise = th.bmm(latent_sde, self.exploration_matrices) - return noise.squeeze(1) + return noise.squeeze(dim=1) def actions_from_params( self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 51a3d37..a88fad6 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -350,7 +350,7 @@ class BasePolicy(BaseModel): # Remove batch dimension if needed if not vectorized_env: - actions = actions[0] + actions = actions.squeeze(axis=0) return actions, state diff --git a/tests/test_predict.py b/tests/test_predict.py index 853f4d1..89cdb09 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -73,11 +73,13 @@ def test_predict(model_class, env_id, device): obs = env.reset() action, _ = model.predict(obs) + assert isinstance(action, np.ndarray) assert action.shape == env.action_space.shape assert env.action_space.contains(action) vec_env_obs = vec_env.reset() action, _ = model.predict(vec_env_obs) + assert isinstance(action, np.ndarray) assert action.shape[0] == vec_env_obs.shape[0] # Special case for DQN to check the epsilon greedy exploration From b1cc15970a40c86b26a247fab5783e025bfe3da1 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Mon, 25 Jul 2022 14:02:53 -0700 Subject: [PATCH 25/33] Use higher resolution time_ns() and avoid division by zero (#979) * Use higher resolution time and round up to eps * Update changelog * Add test case * Fix formatting, time()->time_ns * Bugfix: ns is integer not float * Move test to better place * Divide by 1e9 earlier --- 
docs/misc/changelog.rst | 1 + stable_baselines3/common/base_class.py | 2 +- stable_baselines3/common/off_policy_algorithm.py | 5 +++-- stable_baselines3/common/on_policy_algorithm.py | 6 ++++-- tests/test_logger.py | 14 ++++++++++++++ tests/test_run.py | 5 ++++- 6 files changed, 27 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b258547..b1ed3b1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -18,6 +18,7 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ - Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) +- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 36a73b8..e8032e7 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -422,7 +422,7 @@ class BaseAlgorithm(ABC): :param tb_log_name: the name of the run for tensorboard log :return: """ - self.start_time = time.time() + self.start_time = time.time_ns() if self.ep_info_buffer is None or reset_num_timesteps: # Initialize buffers if they don't exist, or reinitialize if resetting counters diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 99a02ff..b841eb0 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -1,5 +1,6 @@ import io import pathlib +import sys import time import warnings from copy import deepcopy @@ -427,8 +428,8 @@ class OffPolicyAlgorithm(BaseAlgorithm): """ Write log. 
""" - time_elapsed = time.time() - self.start_time - fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time_elapsed + 1e-8)) + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) self.logger.record("time/episodes", self._episode_num, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 763c108..84c89d9 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,3 +1,4 @@ +import sys import time from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -254,13 +255,14 @@ class OnPolicyAlgorithm(BaseAlgorithm): # Display training infos if log_interval is not None and iteration % log_interval == 0: - fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time)) + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) self.logger.record("time/iterations", iteration, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) self.logger.record("time/fps", fps) - self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") 
self.logger.dump(step=self.num_timesteps) diff --git a/tests/test_logger.py b/tests/test_logger.py index 6fe536f..a55f88a 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,6 +1,7 @@ import os import time from typing import Sequence +from unittest import mock import gym import numpy as np @@ -381,3 +382,16 @@ def test_fps_logger(tmp_path, algo): # third time, FPS should be the same model.learn(100, log_interval=1, reset_num_timesteps=False) assert max_fps / 10 <= logger.name_to_value["time/fps"] <= max_fps + + +@pytest.mark.parametrize("algo", [A2C, DQN]) +def test_fps_no_div_zero(algo): + """Set time to constant and train algorithm to check no division by zero error. + + Time can appear to be constant during short runs on platforms with low-precision + timers. We should avoid division by zero errors e.g. when computing FPS in + this situation.""" + with mock.patch("time.time", lambda: 42.0): + with mock.patch("time.time_ns", lambda: 42.0): + model = algo("MlpPolicy", "CartPole-v1") + model.learn(total_timesteps=100) diff --git a/tests/test_run.py b/tests/test_run.py index e4e8a2e..b0a9a11 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -10,7 +10,10 @@ normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) @pytest.mark.parametrize("model_class", [TD3, DDPG]) -@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))]) +@pytest.mark.parametrize( + "action_noise", + [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))], +) def test_deterministic_pg(model_class, action_noise): """ Test for DDPG and variants (TD3). From d532362e94064f7059166a8e53aa8efb1bf253c0 Mon Sep 17 00:00:00 2001 From: Marsel Khisamutdinov Date: Sat, 30 Jul 2022 15:44:25 +0500 Subject: [PATCH 26/33] Adds info on split tensorboard graphs (#989) * Add info on split tensorboard graphs. * Change wording to make it look better. 
* Update changelog.rst * Rephrase and add link to issue Co-authored-by: Antonin Raffin --- docs/_static/img/split_graph.png | Bin 0 -> 21464 bytes docs/guide/tensorboard.rst | 10 ++++++++++ docs/misc/changelog.rst | 1 + 3 files changed, 11 insertions(+) create mode 100644 docs/_static/img/split_graph.png diff --git a/docs/_static/img/split_graph.png b/docs/_static/img/split_graph.png new file mode 100644 index 0000000000000000000000000000000000000000..c966c5565ac2c8fe2427952aeb0bb970520448af GIT binary patch literal 21464 zcmeFXV|1lWw>G+xj%{>0M#px?wv&!++wR!5ZQHi(j*X7h>Hb#wdHL-<&KY~0v;V!x zSoc_~YSvXXXU&?bJ5)|a6dvXy3;+Ot7Z(#!002NFfZyk!Ab?M5i_mNU0LQ~!N!3w7 z&xOF&-p1I>(uly(&DMy($kogk0B~I^OgC|+YOgE!ZGgrN=F0SGrEv?+{pk%nY_4d6 z=P_$}E{;%52*iwt7$s-*_oGkLGp}djN?B;>!iIg5=O9M<)lbEX+r=vzlf5^uZhfD( zqb?qsB%jXFTcvFqSG=X4Kb`QWySIK^FPrt;vC3suV-782s(VK2bi?1Z$4bp!<<{E?aXejTY2DaCr+0pIb@mpG{lxHsK8E}15OWnxyc9uoZTQ0*5|g^yt7IC? z-@5}k(+8&8%T)7NVE;E3O1GA2b*gU6_w&2wL8n90vg_sWF4u$GZkw%BI+N6^t*~EC z;!no=Kc63$k0K?zJr2@-aeJSh@&3MlX1adNoMk&ZJIYX=)?z9@W-vJYslG6ba^<5r zOd#J}6;x`Z@gxo>f9bNyneU?V#<8_4n?1=_0H*)~+AXkMqH= zhVX-Iy#0H;U7OdW4K)IT)u_#oS@+ovc8^VyG7eWNt+)1?r?}rZ1_!&}SbsmjACaSH z#M8X)=^Y+)MocpAbG#iga_vrymz;cj+IxLX{`v8TZd_s_5qy7?Nl5GlJ_*$FtRF0@ zPqw3-Q71%?&;}<1Hs%k2wgX}7x#e{idoUpt!l~3|@G&HwEyeR7i5DE+h{S@~XVwbJW)||&3Q7k(WrW^A7DApGePvZ=; zLp9?}tA~THs@ub3hM4Hj&sb3sCKHzwsV3k!NnI?(%Y|7vY zg4}cF!F}BMvWFi%=I~p`n~Bu>I{o)|_o)%5-*XPUZx89qNnQu$Kdg7SX^)i0Njwa$~{Q$+3B!;}x#DrQVPLmhic`B9r$FU8E^Z8&Ft`0;G7 zQVTd~6~~NpvbtE1k#ILZPumz4Wv3VlGq8&XuiIqbfy1h08>9AkMZW#HGR!_61cs`}u?kq^bA zK{jNbleb@qhYVT>QJqaLvmt&)Ka}eRM_Z#_4mFFX;zuwTHe5pX?vB#u=^ivVPZfAL z>XkR4%4!`ae2b>Px7JA!`5Y@#E1YARKt4i!P){Z0U7*tYrp zRF4+JfsB6Fbcknp!j4j+GgB4Ux;`>m7&FbYEk~q+Sfe$c6lrrtl0u*2_m&q$JgGB82z{ zmcr2D^FEm8CT+qEmPO$rq8Hw~soa@ScTs5?$VhRK(>b(63Y@=9>N%m-hr6>-8aee6 zvNL4qAZxc@;l>%`ia6g+Qp1*EYgOrQM^c)t@fB((R+#0Si?H|)_;cV=dP#(vsXWM) zH`tXn@jwaR9?31{Gdh?LR%uMF(?7t!!W_ufB$K<{uxvSO%v8vIR{QE9@0&IaHCey5 
zrT{)x3weyMc}dT%fZcKpwlYxYEpXT&0%-0pLHC(Zd}OQ{P2}I#1G&3o4JcFuyB170 zLas;kRWTdTq{uZOXGJSPD3E^3EFJE9GM33%g0+@cH_C`CA?-y4y(EMR1(1K}<$<{d zu!R>YJf%S*m9DOvOXCH=L8?H&gH#A>P?*|b*z(Hs`TZ1RQu)HPjlaCBbsJD2;X;A& z6z_um^P{!UPH?F&v=S}*=UYF6>I&FLf&>H^1MMt6d@h|(evsP&733P~GbynFKU0*{ z4`g`XdOT~n2yj|*LWu5pia{GY;^onou5L{@+03pSebvYuE9-B(6bU87TbSY7>|huE)V#2)-&No`;*gS|k#G3P3w@pP+acR;-N}bD6;FRal0Sc%jG=MU zv%_fU9~`aispNI|)u7vAfvN3Nq0K6F+N%pjOoV#vK-l=pog>gogoM zwdB%A@VG5}(bq0cP|%~Vpv~iFJ(c!X>QQi)MR3fMqcKoIkjzo(AyNd^eR)uh+k+k~ zEkUSH4V`{IC}8mWS`>zuY5pXv;G)hsC{YUAx*BtkU{Yp_Zs#R)JY~kHdE%oYtRN=c zD|5ON!_Cp9hUBw$TX=DuBq`+lp}D-e+9cH@aT@|(Jl6UV|Rml91ySyR%TCYwWqAx*(ju!QCAf;AYNJWOk{Jse#2QAod zd5(NRN`oNdtmm|NognDZpZ2R)(Wf>LX}o*|Xt}AV?KyZzIZg{>3@~L9Gm0G?=;T5) zlWzoh3L6Lrzx)MEI7mc=N)wUG9B|M(Rl?okbfl>=#xA|#!0K#-m>d&4=%EX|!^{#MuVzdQdoX%*Dc~s__ttrmTGfh-;=ebr@By z@8bY6-$#f-5PJO=raN?nt7dM5-Z`5PEPOM3@>!SKlqy0FQIr8Qz2xBUFTFxmtEz{3s2!hU4%@b3awnfhgac(vlqU_j$o8N>PDKP`6Onp>hfnHq@ip&xQ))798{s$X8!|?lGVI$-OsFkQix8TI~r2 zSS|q(a_bZ)9;CHs)multkX4&EfO&lj>mi}h^blB83Qf`{+_Y44CB7;qa8@=(7J1ms zgV1qdCtAX=?PC3Q$%Su|2m^oyG8S|r#IkA2$TiX$4Qf0_uCTt#T!(}Y)nq!?P*gDR z_8-*SwpJ(?uWppBL`xZbF-Cu3j1=)~)$Zho9+s0-W9p+M2dGm+=~))Ii&KA(;-dlA zC0&wUl`WGD|V0O==ZaDz1#w4rfm6jc?R{eh$LdPdwT z%v%0=*$sJg-JVLQwEj@9+~z9$x}!~A;c$~DlNSGX-G+KCoM}#*h#&W$`L@4Eo;Jg6 z*VVy3E+u)H@D1fUuheq{*F-QbDg+=$1UkntA*VQ~M0wIktFoE0=}4E;E(gh&P{@{@ z!AnGgMJ`GOPfII{gGfQa3u=7bMT2~hFpmTr%qRCW5%dCBGlFxxgNeilkGa`xT!n z%6GG{+%RVVvB0Kvv!fPG_sin()&`fLY@RI=SfZZIo*p-@aySOb5TK52kEh8gS_FVH z87w9@J^So~esKD;=_QPY;kTEEuBwP>qYuLUw>a;h#r5+E6DjD;B@llhcRsuo`)}%W zJL+C?w z7G4xljIy6~R;i|to`PPWb1EL3sXe0jC9O_#`4wf*eLA4=6FIcq;TnGuipGG)C=P5zf8j*k>S>teW&pSC}{!0)Z+WA8yp2!pdoT9VLMl-Vv5K|I~#Jz_+X zDB;A6$snw(3&$IG1kAn;>l3)KXh=`koOKejBSxZxY-vaY%ib&MLAn3-YHmbhl#mcQ z#(|A?(G|kPiUYOI`>a0VnoV(EB*9h!=MBdaU~41!%{&T>dQ+B;E8{{6ZGSs0qm^B| z3H+#YNkSM3;afUv;e4z`!Uh49F^rt?oNT6eHMQ>-?C^+*%V?TfIjk6t6>NXHAp2JW zp{prUv}jZ8H5$_cXK{aD%rNj^xT#|8GyQ75z?H{qi?cGnh48?8P(toX8k==u`j4?DDJ6Pr+TmO>mpL)Zm@nx0zH 
zP`P!{g(dda@LM=-U;#};{oSK%x3PLUx6)4!j7e`P_=UlCa)AfUJuE!k!d30Duq-aG zk`fEwjgz@Zsj|iyAmt>@@ycNMT)SwhIF9f`0}WnSdSOoc5s@R)``*BZ&=H=~tGfP&j8c@Ax*7%Q!zZvHv;e`Ass}ki+PJ4t%SqPARIcuI8I48DPci$qP}RV!CsRQ=%dX^@6v+uNM2@`*>Q zP9uQZe?4Z2Gz<(>4Ga^SH8WrhVRhN0il89h`&Uds*c+Pa-H8xaMX_r5(NJzuP~JkA zT6fE>v~(aA*$1}<7R+oSdSSO6!QJI-4+DIIG>y2ONY2@N{HwR*DKlZVRM8&d@$i}T zEqI<4k{vkuzQWRl`Bhq1G}6yQ{76q?0-XvOeuFX~(c@yRMT%dAKUTTe{&zM_>C-&`r62vQpVi+XKk*{O_ZZVu$rj3Xq>D9z7Ak8_O`^JjF z-cSMVUm8(r{BYD`0tiJj{X_Anreg{&q!PJM0 zz7m5h5;ggAvMm}lgn$VaNTRRuQLs%k1BigM^o8J^d6Gy_dV$E%c)~7ni0xMn+ zVr?LE)uz80HHu8S7^bSv_=6O`7G@V-FQ*F0|2od}X)Lb&eqr3*$kg2hjttbpg>qtUhz!h4 z9hrZcev`V|2&b-BQnDZy@^#^hFfDSCAIX?)!aw0M$QJn z+6A&RI&sCh%7a-1XL4^bVc)L9!YDP`5dVj7Bm*#}@;{Xz%`D2JXXGvfadi6%6gAi2 zgD8^n6b%UYuk%(<rk z>gYX@H|MLq&f`54{UdHiIO{kW2bZQ~vc+(cgSi5_*k_y>)mO%{XC|{NmZeSf)?k3x zV?p5Xt#{0@ce=Q1LORh<`7KBM4FXA}9@(aGil0|SXkz%@hOTX43rTiB#8JvB! z%6>m~Uv$&Qpta_j1+ZNGj8pjAASk@^jcUD;s{J$^Q|1{eMH{FdqN-ShDf5e7&zzLf z=o;=%dpGc(K1}p{*%n74%lwYri^DAWd4$5P6QxQoST2|MpyP$_M{pe%6f0V~bR3dK z#cnA25nwEh31oiq!uX*$eTx^rKtOdR#`>#3H{A%=1T-l+k9IqapE5_%-*C0K?aQ+Xk$l83^R~8ta4M|Nb7JV24q+Z=BVDXpG)CiX(?DD9pu+;xdS;Pw)! zpLkW)B+Hh1@O)9@H(u+NpJb+65wD}>6CNc(QBaNpEg)@5P71eAju17t{0fnFVRhrm zxWYSnnFH)?$mVcQ8TKA5ZHvT9)lx_(L=mcPHX|@j-J>CTHkO5wpL)0h?o5AsEz$3OU5m*@u{g@vz%yPN zF3iaY*Py{Ynp_j5co{(oj9k-el-7Uz#jMI2A1Fvb9i%ab4T&E-E9;Pu4NrWqiYjDM zEFs?>eAL*C30d1J-*U=EW!7Hig(OJESK{hMz(qWmZGkR0U_1{Q47#ltTD$&f?1Kp3 z$p#T$091rkBTN5BYvv1SsmzI$bOCmlfR4|7b#7zyjYv-UWQKv2j>ai|NCMsp&KBBR zmJ3fg9V(Qa2pgiAx*N*FCkslrBxwk%AST2jm9ChXxBSB9uV`L}_G>XV>;);%N?)?l z8lMY_(s@2JJuZFb-0l$L4ue&{z~e@2204F}XD-cJ)(c~IBR@vDvXvV* zAg=3-<~kdP3&ah z0aWl_DbykIdweOe>c#IW5t&`ra(rD4rJS>nLgizDT2MmT+D&9bVSMy!^^~R@o8%r_ z?0!ZMU)ts<=;&glGp;>|WXom9k5I%$$l+@k&+#xfUHD1i%o7)v1E~j8E1it>(HwsH z;GEH}-@tY~E)J|OP~zq+#!MEFpHnOBRlLq}rc3-Vi)#7GQGZw?Z;lEW9gydl^^R?? 
zvX?3?-vp;ih+-)uh_n@(s}XRtWbg-^pvf5$`YK$&?m%kZPBFqvw2OFw8uM?6FLbCP zB%pldFq?aeb!)-jG8&f~?mbbA-jK(ieii%TFCIxkrFY`kLa;IIC*!dS&FL=pM&=A_ z44(Hj!EgYz=)gywBHYXKHNwwMXA3by>p0!&{?8C0k7xY%z)Q$ zs?t)N1~yi-dWJUoMzpS0w!mvT0Dz0v)mG2I!pM<8-^j$wnw#jVy_<-@%#fStD~mL} zw5_0#shOC&y^*53jFN%7g#m{l5ibu6mn$a_z{&Yo6PY?X+H%s- zxwyE{x-ik&*qhKXaBy(Y(KFI9GSUDkXdK+E9ravktR0BoA^yS;GIB7mH?ws#v#}<4 z$JEoeadPA)A_Ddk{C#~^w$jr7fVXz|n-qZZpmWu;rDLF_r?ax6`_~Z;j>67BkiRYH zza8PA1ia0pQ!sL{ak4is5_UGSb|n5+2t$K^jJI{NxBSD7p#hzdrI8ho)B!jv!@o@_ zDlRSej}h-8Ffp^T{WA(E?0=(lG&BCEu>Ngr?>&Fm`PYR2hyMfj->Co2`yXN;m9#Xc zkd1-U`|`wvxQX89=QOl2Ff-)*>noccyOF*T8x1|H5ep5o0Xsd7J_9oc4Lc(PD=UY- zp^=e3{l7qoTRS-FSsNI=Ljl2Q&44&;#(GRl`g(>mEJkdGG|YyE^fW+B1{zi-eHMCV zMtXV{4x@j8khM1hYNej#zpm;X$`A;pr)SJy%mUO4;N0{y%na-XG#tikj5LPq9PG?2 z`o;{5hI)UX3=KF%Z0xP{fZJ(irDtM9XKQWpr{i7Vocwa)+(eADe=Yn^i=3sNqcLy* zH<6T?wUg^VNtMj3j1(R9-lfUF%ECd<&c?yO!NNq(&cyIfA{8Th2cRatV=~awGP3{a zdEXXJATvP0>b>hJ5a3TckQYutdm}wZ8+#=i8%u7Y_XQBVkNl(F1YCb@ikO)LkizX< zdfam*f3ibDP zv;StZ4D}5d*?~Jo!=c9t+;2U0V;T+yLsl9+HV$KbMkacCV`j#GM|ZF>c68CRH{v$| zath=Is6c51t6_gU2E zFkoR|<6xyR*4Ja8VP?`ZqS4pqFrqPFGBPyKV`5`7V*N{r{vp8sPb_jMF)(v7vv4wS z{5y+Wbnk}i@0-d+_g}XCFM@xubU@Mk)dsYnK!Z&84~zUanZ0Yu|Ki`@)a`$91|al* zCix%9_dnwLkGTFv68Ikx|3|z2Bd-6E1pY_F|Ix1h-^2y;&z#Q48W`xg05djUx1M$2 z3z~suf`W45f`Wg49}@r|kM)Y>7VGExG^jH#p-hS3j3hrImq{HG#%_)tCP%mcQ^}d4 zxu$M97gNx}4pRjV$%|2rKs(jO*#hIg4H=RjADYj1F zVg@Vcq$yQ~Y{1f57_OTrg3ysLDrSP(KqXG$2d>T|)Q@}GsrsMOZyfFm9fQ8UVqE=j z^JLnev(_5mr(Or-n+VGXjt26{v|^rO&%%anGQw$LxtPq2=Nty)^=ns#Byj5C6^XM* zqDs?bo&}Z+5^pg~Dm9HJJ#H{XZeDIsmdy}|Lcys8{Gq^=6Xb%Vy%KZc^$De4q49tF zo+XbI#*~AJM7b&D@H@Y4`}}CvD5qz*Bo#*C8xG&|g!{0gl_T$vab+-UmI3HB*_N<* zrgGAKiN8~q495lOT<;lll2m-~9mmc2JsKmun&{6QC;;Nhn{P7);?>OeShlR;fM1Nr z7t|)P34jdl5sR%`_T1WB-CFN9Nqc((sJqx_E@OVA1*YE+`jVnTfZy-`Ii1Dvz!qp* zF*OGO02cNA0|H3PzydZxI*LmRLmoonLero&ezQ`4V12qFs`fe zN0AKZpAkNd2&WL~UV*1!2&5IYX7SDyz|%ZxZ#7LH(QUEWqy-;&-Avy8q@6AwKS{+A zMgrq!r*yaTBY^UUD*S(V2z(u!*}`|7zmJcP3Wmp(G7=`O?s)J?vmGgbHo*3IAosp~ 
zQkdK*0C|I5xu4XYT$wL7LEK*;onGuj8ON^Ojw)XK6$ucL+QYkUr3+sepd6uqdQJEj#m{k6DM$E)&bM4usl z`GhoxQx8u3ga*o=K!D=^^${Wm<4^G4Jsdx-b$ZMcgnpOi#A9`&H!?b4KWCL@A24JD zF&)D5RVr6jd?+X^Qox7ORZ<47T~hJ`0T!HjA71^cl)AZjaX6uJaevim4bI(DpFBb! z(No{7HLBq7Y1P7|0EgxA_Mr$CoGLYjG`kebfMy=L2S7nVYOQ`a!(XO&K7_u|^fW6~ zQEIKjWVq(D;$p0sDK-7PuO$Ii#rj6~1y zSj}*WX|xZt@qQP+HSt zxog93>yPp(%N=l9Iu`~`sloO0G;ocnEnUgn;`BD_s%yKiLge85GmdWXIj8p}?zqn) zEBVU~B5AR!Z;YcXY)MUtoYB8dsyI*O`i%{^6r}y^iOELF`=#= z2;*N}*n6Yrq6N_MN$y2^Rfq(wPyVbS;J$XEvj_}LX;H;cRW|qk(c9H;SysB2tGBwJ zro5E#b>=|A4q1Z#cpns7XZJ(x>X}2PiucTC?vu7RLe=NW_C<-kBK^|S`<4%c#-7?D zJ8G?=3O|2cBv#fXhy^Gh1t_A3=I8a79ixDrFk)k+>&c*uxIR3Iv_g7hgE}|dJn3mt z<1np4{cbIdkgT6?Dfb~mL>xb1+)H|A6PjtbgPzr~2EHWjBuvhd*HCB-vS;bG2T zsU^URU7_uziVLL1$4|B%MHuu7hKC3dMc}c-r59@sGfh_#e4z?w5{=r<6^;g>w>pi? z;IIcvZ?-fVYAEe!uxA-0EnKl=J)BCN^bd>{TY&8ohccWWO(_}}yngM(W!|*IC2e}R zj^V{ zYr4>M8#Dq$aKC=3c>a+;(0En({AF3&ixfEfV61gNYc2!dEl zQJG$Q0ng1Sw^bV6xD?7*<$9NtyWPMKJ7@BketD?Ce(|ev zb62Zs;PLvTNCH`uiZUEVJN68ZN+^9YQum@Xz{rDhDTiJW7PV%Zz?AlrS3_SaDD|_8 zOWa&Lg|P$7S8Y4&ZnrxZ09GXBjvUFgJNYLOma-C&L^1ZbyDeMx0CUS{SxMfpc+nSm zy!5HH=(^@rU*YP0%uq%yWh3wIA&TU%7kV4~5`+5BJ!G;3&U2;97#lo+OZcZ3vOqil?Hb%V(*#`D$GCKfwiZkpO;tNZ<09yX56d5EoG4 zg5=UNAmz|L`0?coA|`kze2%`$=gOjY>hexVEsKKthL;Y~b;ZcQ$b#Xb(LsnubFYXU zLeiCo5uU3!B!`}p*IyCLr!>6mqz93fh*ABP{&_C<@hag({@0HOjc479;Dp4)&x!Lg zU&%be!vG`pt+wqJP=Nv^qJ|i&w3y15rx?6)FY;MjbI~{K2il*{M9o-^#hbB4>r$T< zea{PS+A4=m8ju7X5U|u3iZIoR3|CuyJV>t`?+h14AM!g+J(3kwv`35Rnn*Ig|9B=d z-tdOS;r$NQQ<$td@-oc|kh`cVBH+#L6c%2&aGp?M(yoq>g%PL^6E4n=9vpe5rQ*HJ zPEBd;vNAAB{@soF3@WuQH}*^md-O12C>oh8`{=bVaqvhc4ha+}BA2?nWVw#`GHHnGC9rLCD@QC^Z|wk#QjlbXIkAb@mKDP!UOjA8Ad5#D{eRT?4iB8>3E* zBufNRE*jWQLu#CfY{))pd)wDi8leB~gT59|jBchvsIR-rF^mOTF9N`EWmLf)k&f?c zlSDig(1;=d=nw7&Q?f#}p+QJ@k=|Q&kKrgqOzi_*~|1GRf%GVM>7f9MUGjG`-bLD+mi2*e;Ofg^W9hfy@P1B*MPNw8 zZaI=uuvbBi>D`SVK3D-AJn)6kY!EM`c?88HO2i<&@u6`vWfR9MEVQOM$o9V0qy_!T z58*y_U?y}h#yH@n2&YOMLn1jI_e(d?_{l;*KlS0t+Wu)O}Mk_4Thw1GL&ahoEeQoQzMkY_KthqPueEFsyrB>BN#<+I68MXj`0HO;HM6a 
z1jba5cfxlJ3xmnkAB1lp%m#?QgXT*K+`dIG@=GuJMm8RI`AWEISPD!ti$S^0`1j}@ z^q6oaJx;e89b4@V-0evdaagA$38mp2`!glO3$~w&6A)6`QcJ05os{bh?me| zU+CQJ1NqyLM_p81_xV4`swuM~c`l5g`TOb_tQ75~Xswdg||z$tm|G^TrDAE4ah@hZqWwFp+J4qaEGt`fn1^3xv>%{n^{Guv{@MxQJ85k2fIB~Xa$i#BXOivXH)n6=Mdz4D=T8g^ zWVlE;wnDpE0l2Wb*pueXzBbtBm(A_@v6hJo{ZXLDfTwO1)#00C%Wh#l>i$nK-dhG;+4(=1h^@#)D_B$!i z=?2B1d;Z|$`4J z@i2FA`u}n;_7QD_fg&+022Dr>oe99Exc^=*rZ`#Ys z(mG9mY8;h2k0J4mf2WacE3aw_v8UYTCw}+O!Re=a59YebG)7#oH>9Oq!NKKXyv8tB3abC8e zp%47sW#IMCbq0iwn6I@yXS0e?ehFx(CkT+7xF~{|EMtCZ#wgWbF_ zZJVw>kw=E@rK|J#Pd32fr8&&atpRrGX4I3qu#0m~l%WDt{9_ z%%%uQd=vRg=f0+c%`QWo@VGaT3<@wnM;155O1@i);dJHVu8O>9TD|`PFCIP#&L5y~ z+(d+ZvxLOe!fS@;GFz8DvL6XF6(M3n!i=^#BxXVF?wb;2%atgi&_E9mV%Rd3Zt6Vi z7^Lx03yLT~ptNL+a5%J<07>Fq7WuCk24RIH5a>90m`!z(dkU4jUf+;{85YJdHf^&l`n`Br zVE*6xfn^8D#2DgTgO2CF$SU&3$0rJl1by)Ff_y|#f;&FvSYfOu_|$j#9r7Dwo{ZNg zuV3VO!IZO|uk4+6f?u|#(c+Al!<%~`GzJApk!(YrzHLSodR8)K; zCk=|4ync}%!z!Gi7`Mu2s;<>i7fwxB7f2^ENDwGS9|a1hZBF&pbU8DSC{7@q5W8@u z)RZ#rKGB1o(_^RY7Fz0UI?gH1-!Q>nckCpcym^l(mDd5f=$8p{87*ck8c9By5Xl~q z+SeDGh$(G2ts87W?x{T`ew$_n3D8o8h>3gy0UDJ81iWk{@vEtOqPX{p0nw@w$s~9` z>Kn-c2OF?q=7OP3AC>K0vbiHzg*hCK5a=x2kb>r%0HAZJ$@X%~JF&iROPt#lI9AP- z`Y$%Oc6;$8#e2PS|6zjgW2Az1ZFJw!SHORlX>?F!B=HRi&pAC$41DVeOL`N+`%d84 zAt%#)fwAOOmz9^dfAKY*mOc7gFap`Y^@Eur$7{u`m@6lxOL(OdN`00aTf>dURUTP07Wk zSRTnFtWq10ITPhr3vIGJD=?@$x@h_=We8ECNILNlg5@F^Wl1tp#8@|@Qg|K2^A@UH zFYG9Fd^9si9eHvc+0Qp_-f-LuuYEos0fTEa(r3vWEGRHgnl7F9I9IDOIcLh0KY5iZ zNgU4^_5MY^>IFUsx)#H8RxmMO>n9ENpvDNh-z!j`|I}r6@J~cn9~c(a8q9em^2-=a z9;mZ4U!(+}1apzf%E+{m|7IX<#CXRP^ULp0qzvP-^!-^9!4NM$`yB2WN~PG3sTp=8 zZ`Xo$a>sfnSm6j-@t5D{FPu33bbXN0b~j?VBo?Y!(~{=*<^jg&HR~5CmBt;~*x6`N z+Sb7xEXQy|LWtrqYF4?h30s&tt|0wEa6=^@QIEKSJGol#%wWEDjASyI_{s#2ni{I@ z*k4z~{1mj1iytBh6~X@T9zt-P5K=g5=V~^%iqBoYS}fiCN^7LiKD%~J)1rft7Op$F zZOYH(Q&5CWUBqaJI={zQ(HR7fF#zDg$Gbk zDyCAcK|w_EUT}(wOEb0BJMl4&6O%Z?&#fn1HwkEhrWVBjoa|MRorS6SI{rd_M`B2q zk``qsj&r!*TOOxA{9x^3K-Y`VG`Xl7oCJ(W7Z(+kh!`CE7#4b4 
zq6|daVGn+NP@bF$B-ZoixB9szOjLXH+Y}gEvxqK*ai5TZ2czV2VJ$>R+kI&;h~`&( zl(Jx0Siq|soSBWGiRcp#HOyKJQC!SDb=pPBT_=d==z|X5SCfIoJX*s_H9F*F)U2R- z{4tjG(Jt!UaoELbGStfMZa}3+p1H3d8!10;HofhVS z%6ep|HYopY?z~3p2eN87h;EOgwKF}?kkP@AC345-*(L+Zj>QzwL-5|(;?{iGC4m~! zhQDh3q12q3Ys_mHLjbtHK%eMbPzJMls5;ttWa}#j#<;t^T>bqfj9tU~R+rI|FW}c68ZC*N10Ja++$SD`$aS3sMmC)BXY+^j{vD zbW@{E3JN%49S{uH9Y^&Fc4-6x*e2V)-ghIThs%h(Fg=a&)rND=N43Tm!4oO0iODKe zMZoGs`sC0CN4;9j59Dj>&&1Q_B%1|*nrhk*`D$5|Q3jvqFs|}g?N;s;g*F|g^=UAE zra~8QZM^lSE3r>DYd^ugZahA0i#025)Ko9-U2oq%*Qgtfika{)o*kzfON8ua_4{E6 zGisZEeInjv<`1YNZMs+)ePS>%kki*G5LD>l#r#&2_qTS1UT+U%O zOtnKBKCz{@ z)vw=veHiYN`YYmG)ldr9v>vu(*}HG*xU_za^7)L5h{Fg4Om(q?2VdvR%tV&2s%=<2 z>LwxT;3>U*xVLlzbVy3ux`?xyMrS&B-ry zyCi~)GP$6By>9~bR6zdwZ3gY7s5 zDjfd=Rl|kSTR2nm*7nX@(C)Cxw&Vw0&zz%E4 zrqviz(;aw`6^E8Ahq?++yUFbVmN;P;pb1D=W4>JUzVGV_Wz?t><44LpZ(_c9nx8Y2 z7Ps>K`mi8|0n9QJHkraF7|KJ9F<9uF>l(suD@|UI+#dU;eT1d( z(KV(v9@Jq7;e@Q0ri*O9aC#E6L?n6!?>u4iack$=H}H*Dwef9KUiiiHO`;o|>lY6u zgyioQ%nFiqMFe@rh}Yko&<=rx2EQN~CH0=v)ZEmo6cvLKq~)pJd9w7#QT zU~bkc2bHsVE`QGnQ$bk;9qLiU;=zn1n1(#oi0>8mep2;oQ#f2l@}Ywl zmxsM~*7~!a_9MPyZrk=}S^BQYC8&L4Y^}9%7YxN!l&g#o|X+#n9 z(?yD(Ey4BxbTl@m)rpq>DS4=GS+`WvlEq2meC4x*v#%7Lj>mCURKP(xf);=z5Llf8 zP1RF~OL|Ip@^}5Vg8|OibzwWQ?FWpEft~jp_MfG>!YArK#y5RRT7Y;0dXr(VPu~cD zJFlfe(%B0|+6dwwOAd4;@pyv-pntR^PN}hWVatKtmkrB8bGBh2Nt6M2F=y0_G3r`; z?d8-PAjW!uO!37W3}`>s^G=k=ib7oC!Z+5==0#zEx%)f8gkh z^SR7jOTW@EluoQo9yt~*;8NZFx9-{`*X?0JOU)|z#y#~bZnUy+mka{8h0mt6Z&EnVZa#{}bC!X|J z0`dt34q^HPNTzNll3)LnfM?5tRDX}iCFmed=a_``K{RKt^QGmX*Ck1LF!y#`A{k2W>aDU#QWLxBX+T~uCLr8dS)e6`Q_nA>95Kj1R*FXMqSG0!& zKj+u`s7YI$X5(n)79_Aya&XfPFNen&>Z+0YP_cJ|Xj>rmP>k~lX{d61=N>bxm&h=dcsLO%I9U9Id@-j^0-Ol4cZ3f2LG_{YmzG~TM;lyCM zIK=f^{sd6N{K@l7Y-?HnytR%XB@Mjz(#R058S_27BWu;OExiIf5J21Gir@jfbe4zXuDw&UO5Wn>)?wZ0r?*1SIV2pOQ(zW+}?ML)L z%zzx#dvwbHtbYR?#@`+Rtc9BMXNq%12de?B@&pEJWQcYSA*<+oeka3CiLQ<%U}0*T+3M> z+29F-G|M+6(555^5KjYr36Y?cylw`484Iy{fOj19ZZvCE5U)>y`fOM!qsRR!CuFIC z0a;$Dcs?z6++dsO7@Vw8tC$5l8yiTJ&L`@tsQWO<#v4W_r_2x@V%5nK$qOyBGZ>1> 
z6-ww>wt_Nyq2)A9JS@=k^hO}D>BzKzmc078xwsB^-^}zBQf+qnadY^5aPvcY0;U?o zOkMwB%-U1OV@S5A?$U48o!BlzfM4}F zv+Hn2k2Q*A_z48MiFPRBbrr;JZ@lBp(wi4sTSQkasZFi>B$8bLWQ!iQko(hXX(o`M zcc@hNYHT(nK5&sUww>Y-Kf(_z-J7C*-}j967auLXz%TYoOjABETVBM1>@A9zd7VRS zVCZS_3DYaeOcC4UrRxC~S%$Y_Kxs^NB7m2k`$TKri$;Taocx_rLY(#=OY*wjg7}X5 z=p196g&Qpi=I#I9Fs_e@p@DYi=Z#hP5Kt3e_b0f3ayIlq4eG-ZRWem&ebKiVKY<)R zxOx<9{7M~{GgSPl%YVC@r9b$#{l&q$-Mg-o-4V+80~F-)3U!>w$Tnm)xtc-`laa6{ zhb2`syd@9uDwk^(#ljbp!3Dw-9);&meaKT12*47}Y^bJqdLUbK%=5D5T(K0pnl}5w zR<+`US{TPev|U$p?(!2zFom^dv9>lHWG{t=t5$n>=V^5DD7as=2Rs!tB;6Fsq$JjQ z;*r|48IzKwdx5#zpH^}5&nV}lMCR1;<)Da#NdZ&leU_sqM%dWf=fgj7Ou+;9_Ax$jGooSgegvEaHTJ*ugO&favq3~s)Fiw;MoLaYN0!*r$oGQS&)*&{}#j1RPJw?+@N!^J#?92jm{p{;T6t>1maO$PRd z=segOBG(t+K$tF`&^uNTZx_W(O zm^8$8-T{zm(3vxYlY2-1U|gy~fvIW|9Kb|I)q~ym0P=Y^?~`FqfcgWS)<&g1tg9(3 zpFi{^w^(l`_{7P`fX+)>4w<^>R0kHNC?LMy!Ia8EE7+lQ;;Y$EVso;YuYcgIP41d- zxSySTy47%J${O)QAS)wHuMK_v30j?Kxtj;z99X2n&?+N)NL%H^MosIcoI6)QAD&9H zP_1_kEs;J=G z3Qi0Vb-G@0jCnUu>f&3=LA%l$(g+L%YgOI%e>gshmDBUP-`33w0Up###k)0XXP%&A zcWhG%>7R-EE1vsbxAvCl-}XJ+&2WrCYXtIY>Hyg<{Hv8YOWi!bFS#%D)gB=V8$j(3 lY^wZb59+^`_). + And, if you still managed to get your graphs split by other means, just put tensorboard log files into the same folder. + + .. image:: ../_static/img/split_graph.png + :width: 330 + :alt: split_graph + Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command: .. 
code-block:: bash diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b1ed3b1..6a42744 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -29,6 +29,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ - Fix typo in docstring "nature" -> "Nature" (@Melanol) +- Add info on split tensorboard logs into (@Melanol) Release 1.6.0 (2022-07-11) From 646d6d38b6ba9aac612d4431176493a465ac4758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesco=20Lucian=C3=B2?= <56108851+francescoluciano@users.noreply.github.com> Date: Sat, 30 Jul 2022 12:52:35 +0200 Subject: [PATCH 27/33] Fixed typo in PPO doc (#983) * Fixed typo Fixed typo * Update changelog Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 7 ++++--- docs/modules/ppo.rst | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6a42744..3acdbcd 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -28,8 +28,9 @@ Others: Documentation: ^^^^^^^^^^^^^^ -- Fix typo in docstring "nature" -> "Nature" (@Melanol) -- Add info on split tensorboard logs into (@Melanol) +- Fixed typo in docstring "nature" -> "Nature" (@Melanol) +- Added info on split tensorboard logs into (@Melanol) +- Fixed typo in ppo doc (@francescoluciano) Release 1.6.0 (2022-07-11) @@ -1014,4 +1015,4 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@Melanol @qgallouedec +@Melanol @qgallouedec @francescoluciano diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index 829aa33..d0c425f 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -8,7 +8,7 @@ PPO The `Proximal Policy Optimization `_ algorithm combines ideas from A2C (having multiple workers) and TRPO (it uses 
a trust region to improve the actor). -The main idea is that after an update, the new policy should be not too far form the old policy. +The main idea is that after an update, the new policy should be not too far from the old policy. For that, ppo uses clipping to avoid too large update. From 6ce33f5bd2dabe389509845ce8789d30de53f298 Mon Sep 17 00:00:00 2001 From: jlp-ue <54306210+jlp-ue@users.noreply.github.com> Date: Fri, 5 Aug 2022 17:54:48 +0200 Subject: [PATCH 28/33] Fix url in docs (#1000) * fixed URL in docs * Update changelog.rst --- docs/guide/install.rst | 2 +- docs/misc/changelog.rst | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/guide/install.rst b/docs/guide/install.rst index a9bb761..0169495 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -60,7 +60,7 @@ Bleeding-edge version .. code-block:: bash - pip install git+https://github.com/carlosluis/stable-baselines3/tree/fix_tests + pip install git+https://github.com/carlosluis/stable-baselines3@fix_tests See `PR #780 `_ for more information. 
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3acdbcd..b7f27ab 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -31,6 +31,7 @@ Documentation: - Fixed typo in docstring "nature" -> "Nature" (@Melanol) - Added info on split tensorboard logs into (@Melanol) - Fixed typo in ppo doc (@francescoluciano) +- Fixed typo in install doc(@jlp-ue) Release 1.6.0 (2022-07-11) @@ -1015,4 +1016,4 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@Melanol @qgallouedec @francescoluciano +@Melanol @qgallouedec @francescoluciano @jlp-ue From c4f54fcf047d7bf425fb6b88a3c8ed23fe375f9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 6 Aug 2022 14:19:20 +0200 Subject: [PATCH 29/33] Handling multi-dimensional action spaces (#971) * Handle non 1D action shape * Revert changes of observation (out of the scope of this PR) * Apply changes to DictReplayBuffer * Update tests * Rollout buffer n-D actions space handling * Remove error when non 1D action space * ActorCriticPolicy return action with the proper shape * remove useless reshape * Update changelog * Add tests Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 1 + stable_baselines3/common/buffers.py | 9 +++--- stable_baselines3/common/distributions.py | 1 - stable_baselines3/common/policies.py | 5 ++-- tests/test_spaces.py | 36 +++++++++++++++++++++-- 5 files changed, 42 insertions(+), 10 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b7f27ab..e9d4b78 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -19,6 +19,7 @@ Bug Fixes: ^^^^^^^^^^ - Fixed the issue that ``predict`` does not always 
return action as ``np.ndarray`` (@qgallouedec) - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. +- Added multidimensional action space support (@qgallouedec) Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 5ed9b4c..0eb2651 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -247,8 +247,7 @@ class ReplayBuffer(BaseBuffer): next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape) # Same, for actions - if isinstance(self.action_space, spaces.Discrete): - action = action.reshape((self.n_envs, self.action_dim)) + action = action.reshape((self.n_envs, self.action_dim)) # Copy to avoid modification by reference self.observations[self.pos] = np.array(obs).copy() @@ -433,6 +432,9 @@ class RolloutBuffer(BaseBuffer): if isinstance(self.observation_space, spaces.Discrete): obs = obs.reshape((self.n_envs,) + self.obs_shape) + # Same reshape, for actions + action = action.reshape((self.n_envs, self.action_dim)) + self.observations[self.pos] = np.array(obs).copy() self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() @@ -586,8 +588,7 @@ class DictReplayBuffer(ReplayBuffer): self.next_observations[key][self.pos] = np.array(next_obs[key]).copy() # Same reshape, for actions - if isinstance(self.action_space, spaces.Discrete): - action = action.reshape((self.n_envs, self.action_dim)) + action = action.reshape((self.n_envs, self.action_dim)) self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 5247751..fc48625 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -658,7 +658,6 @@ def make_proba_distribution( dist_kwargs = {} if 
isinstance(action_space, spaces.Box): - assert len(action_space.shape) == 1, "Error: the action space must be a vector" cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution return cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index a88fad6..3809b61 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -336,8 +336,8 @@ class BasePolicy(BaseModel): with th.no_grad(): actions = self._predict(observation, deterministic=deterministic) - # Convert to numpy - actions = actions.cpu().numpy() + # Convert to numpy, and reshape to the original action shape + actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape) if isinstance(self.action_space, gym.spaces.Box): if self.squash_output: @@ -592,6 +592,7 @@ class ActorCriticPolicy(BasePolicy): distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) log_prob = distribution.log_prob(actions) + actions = actions.reshape((-1,) + self.action_space.shape) return actions, values, log_prob def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution: diff --git a/tests/test_spaces.py b/tests/test_spaces.py index b754042..0696492 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -33,6 +33,19 @@ class DummyMultiBinary(gym.Env): return self.observation_space.sample(), 0.0, False, {} +class DummyMultidimensionalAction(gym.Env): + def __init__(self): + super().__init__() + self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) + + def reset(self): + return self.observation_space.sample() + + def step(self, action): + return self.observation_space.sample(), 0.0, False, {} + + 
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) @pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)]) def test_identity_spaces(model_class, env): @@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env): @pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3]) -@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"]) +@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()]) def test_action_spaces(model_class, env): + kwargs = {} if model_class in [SAC, DDPG, TD3]: - supported_action_space = env == "Pendulum-v1" + supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction) + kwargs["learning_starts"] = 2 + kwargs["train_freq"] = 32 elif model_class == DQN: supported_action_space = env == "CartPole-v1" elif model_class in [A2C, PPO]: supported_action_space = True + kwargs["n_steps"] = 64 if supported_action_space: - model_class("MlpPolicy", env) + model = model_class("MlpPolicy", env, **kwargs) + if isinstance(env, DummyMultidimensionalAction): + model.learn(64) else: with pytest.raises(AssertionError): model_class("MlpPolicy", env) +def test_sde_multi_dim(): + SAC( + "MlpPolicy", + DummyMultidimensionalAction(), + learning_starts=10, + use_sde=True, + sde_sample_freq=2, + use_sde_at_warmup=True, + ).learn(20) + + @pytest.mark.parametrize("model_class", [A2C, PPO, DQN]) @pytest.mark.parametrize("env", ["Taxi-v3"]) def test_discrete_obs_space(model_class, env): From a30d36002b066a16193f4534d0aa74962225c508 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 16 Aug 2022 10:53:22 +0200 Subject: [PATCH 30/33] Fix `DictReplayBuffer.next_observations` type (#1013) * Fix DictReplayBuffer.next_observations type * Update changelog Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 1 + stable_baselines3/common/type_aliases.py | 2 +- 2 files changed, 2 
insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e9d4b78..3fde0be 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -26,6 +26,7 @@ Deprecations: Others: ^^^^^^^ +- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec) Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 7e69d39..f4c29ab 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -50,7 +50,7 @@ class ReplayBufferSamples(NamedTuple): class DictReplayBufferSamples(ReplayBufferSamples): observations: TensorDict actions: th.Tensor - next_observations: th.Tensor + next_observations: TensorDict dones: th.Tensor rewards: th.Tensor From 792e3bcc275cb5f71f894e55c37a374ca9b744c7 Mon Sep 17 00:00:00 2001 From: Burak Demirbilek Date: Tue, 16 Aug 2022 14:32:32 +0300 Subject: [PATCH 31/33] Fixed missing verbose parameter passing (#1011) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/base_class.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3fde0be..b01d60d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -20,6 +20,7 @@ Bug Fixes: - Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. 
- Added multidimensional action space support (@qgallouedec) +- Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb) Deprecations: ^^^^^^^^^^^^^ @@ -1018,4 +1019,4 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@Melanol @qgallouedec @francescoluciano @jlp-ue +@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e8032e7..9445ee4 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -392,6 +392,7 @@ class BaseAlgorithm(ABC): log_path=log_path, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, + verbose=self.verbose, ) callback = CallbackList([callback, eval_callback]) From 73822c34da221f85b3b736518cc5b49c315f9480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 16 Aug 2022 17:54:55 +0200 Subject: [PATCH 32/33] Support for `device=auto` buffers and set it as default value (#1009) * Default device is "auto" for buffer + auto device support in BufferBaseClass * Update docstring * Update tests * Unify tests * Update changelog * Fix tests on CUDA device Co-authored-by: Antonin RAFFIN Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 2 + stable_baselines3/common/buffers.py | 21 +++++----- stable_baselines3/her/her_replay_buffer.py | 2 +- tests/test_buffers.py | 49 +++++++++++++++++++++- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b01d60d..88c3c90 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -29,6 +29,8 @@ Others: ^^^^^^^ - Fixed 
``DictReplayBuffer.next_observations`` typing (@qgallouedec) +- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec) + Documentation: ^^^^^^^^^^^^^^ - Fixed typo in docstring "nature" -> "Nature" (@Melanol) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 0eb2651..5972531 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -13,6 +13,7 @@ from stable_baselines3.common.type_aliases import ( ReplayBufferSamples, RolloutBufferSamples, ) +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize try: @@ -39,7 +40,7 @@ class BaseBuffer(ABC): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, ): super().__init__() @@ -51,7 +52,7 @@ class BaseBuffer(ABC): self.action_dim = get_action_dim(action_space) self.pos = 0 self.full = False - self.device = device + self.device = get_device(device) self.n_envs = n_envs @staticmethod @@ -157,7 +158,7 @@ class ReplayBuffer(BaseBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param n_envs: Number of parallel environments :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer which reduces by almost a factor two the memory used, @@ -175,7 +176,7 @@ class ReplayBuffer(BaseBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, @@ -328,7 +329,7 @@ class RolloutBuffer(BaseBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: 
Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor @@ -340,7 +341,7 @@ class RolloutBuffer(BaseBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, @@ -493,7 +494,7 @@ class DictReplayBuffer(ReplayBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param n_envs: Number of parallel environments :param optimize_memory_usage: Enable a memory efficient variant Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702) @@ -507,7 +508,7 @@ class DictReplayBuffer(ReplayBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, @@ -658,7 +659,7 @@ class DictRolloutBuffer(RolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to Monte-Carlo advantage estimate when set to 1. 
:param gamma: Discount factor @@ -670,7 +671,7 @@ class DictRolloutBuffer(RolloutBuffer): buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 3c19aac..e3fc63e 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -73,7 +73,7 @@ class HerReplayBuffer(DictReplayBuffer): self, env: VecEnv, buffer_size: int, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", replay_buffer: Optional[DictReplayBuffer] = None, max_episode_length: Optional[int] = None, n_sampled_goal: int = 4, diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 45c5e6a..0e028e6 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -4,9 +4,10 @@ import pytest import torch as th from gym import spaces -from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer +from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -71,7 +72,7 @@ def test_replay_buffer_normalization(replay_buffer_cls): env = make_vec_env(env) env = VecNormalize(env) - buffer = replay_buffer_cls(100, env.observation_space, env.action_space) + buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu") # Interract and store transitions env.reset() @@ -94,3 +95,47 @@ def test_replay_buffer_normalization(replay_buffer_cls): assert th.allclose(observations.mean(0), th.zeros(1), atol=1) # Test reward 
normalization assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1) + + +@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer]) +@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"]) +def test_device_buffer(replay_buffer_cls, device): + if device == "cuda" and not th.cuda.is_available(): + pytest.skip("CUDA not available") + + env = { + RolloutBuffer: DummyEnv, + DictRolloutBuffer: DummyDictEnv, + ReplayBuffer: DummyEnv, + DictReplayBuffer: DummyDictEnv, + }[replay_buffer_cls] + env = make_vec_env(env) + + buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) + + # Interract and store transitions + obs = env.reset() + for _ in range(100): + action = env.action_space.sample() + next_obs, reward, done, info = env.step(action) + if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1) + buffer.add(obs, action, reward, episode_start, values, log_prob) + else: + buffer.add(obs, next_obs, action, reward, done, info) + obs = next_obs + + # Get data from the buffer + if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + data = buffer.get(50) + elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: + data = buffer.sample(50) + + # Check that all data are on the desired device + desired_device = get_device(device).type + for value in list(data): + if isinstance(value, dict): + for key in value.keys(): + assert value[key].device.type == desired_device + elif isinstance(value, th.Tensor): + assert value.device.type == desired_device From 57e0054e62ac5a964f9c1e557b59028307d21bff Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Aug 2022 09:55:40 +0200 Subject: [PATCH 33/33] Add Quentin to the list of maintainers (#1014) --- README.md | 2 +- docs/misc/changelog.rst | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 
fd6c8db..973a180 100644 --- a/README.md +++ b/README.md @@ -232,7 +232,7 @@ To cite this repository in publications: ## Maintainers -Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave) and [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli). +Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec). **Important Note: We do not do technical support, nor consulting** and don't answer personal questions per email. Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv), [Reddit](https://www.reddit.com/r/reinforcementlearning/) or [Stack Overflow](https://stackoverflow.com/) in that case. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 88c3c90..eca1173 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -986,7 +986,8 @@ Maintainers ----------- Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_), `Ashley Hill`_ (aka @hill-a), -`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_) and `Anssi Kanervisto`_ (aka `@Miffyli`_). +`Maximilian Ernestus`_ (aka @ernestum), `Adam Gleave`_ (`@AdamGleave`_), `Anssi Kanervisto`_ (aka `@Miffyli`_) +and `Quentin Gallouédec`_ (aka @qgallouedec). .. _Ashley Hill: https://github.com/hill-a .. 
_Antonin Raffin: https://araffin.github.io/ @@ -996,6 +997,8 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) .. _@AdamGleave: https://github.com/adamgleave .. _Anssi Kanervisto: https://github.com/Miffyli .. _@Miffyli: https://github.com/Miffyli +.. _Quentin Gallouédec: https://gallouedec.com/ +.. _@qgallouedec: https://github.com/qgallouedec