From 4fa17dcf0f72455aa3d36308291d4b052b2544f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 2 Jan 2023 14:51:11 +0100 Subject: [PATCH] Standardize the use of `from gym import spaces` (#1240) * generalize the use of `from gym import spaces` * command line get system info * Documentation line length for doc * update changelog * add space before os platform to avoid ref to other issue * format * get_system_info update in changelog * fix type check error * fix get system info * add comment about regex * update version --- .github/ISSUE_TEMPLATE/bug_report.yml | 5 +-- .github/ISSUE_TEMPLATE/custom_env.yml | 30 +++++++------- CONTRIBUTING.md | 3 +- docs/guide/custom_env.rst | 12 ++++-- docs/guide/custom_policy.rst | 12 +++--- docs/misc/changelog.rst | 4 +- stable_baselines3/common/base_class.py | 15 +++---- stable_baselines3/common/distributions.py | 3 +- stable_baselines3/common/env_checker.py | 2 +- stable_baselines3/common/envs/identity_env.py | 40 ++++++++++--------- .../common/envs/multi_input_envs.py | 11 ++--- .../common/off_policy_algorithm.py | 8 ++-- .../common/on_policy_algorithm.py | 10 ++--- stable_baselines3/common/policies.py | 24 +++++------ stable_baselines3/common/torch_layers.py | 5 ++- stable_baselines3/common/utils.py | 32 ++++++++------- .../common/vec_env/base_vec_env.py | 13 +++--- .../common/vec_env/subproc_vec_env.py | 7 ++-- stable_baselines3/common/vec_env/util.py | 14 +++---- .../common/vec_env/vec_normalize.py | 16 ++++---- stable_baselines3/dqn/dqn.py | 4 +- stable_baselines3/dqn/policies.py | 18 ++++----- stable_baselines3/sac/policies.py | 18 ++++----- stable_baselines3/sac/sac.py | 4 +- stable_baselines3/td3/policies.py | 18 ++++----- stable_baselines3/td3/td3.py | 4 +- stable_baselines3/version.txt | 2 +- tests/test_env_checker.py | 6 +-- tests/test_gae.py | 9 +++-- tests/test_logger.py | 5 ++- tests/test_predict.py | 3 +- tests/test_spaces.py | 15 
+++---- tests/test_vec_envs.py | 39 +++++++++--------- tests/test_vec_normalize.py | 4 +- 34 files changed, 219 insertions(+), 196 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 8defe9a..1382f85 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -52,9 +52,8 @@ body: * Versions of any other relevant libraries You can use `sb3.get_system_info()` to print relevant packages info: - ```python - import stable_baselines3 as sb3 - sb3.get_system_info() + ```sh + python -c 'import stable_baselines3 as sb3; sb3.get_system_info()' ``` - type: checkboxes id: terms diff --git a/.github/ISSUE_TEMPLATE/custom_env.yml b/.github/ISSUE_TEMPLATE/custom_env.yml index 7887ef6..cf624c0 100644 --- a/.github/ISSUE_TEMPLATE/custom_env.yml +++ b/.github/ISSUE_TEMPLATE/custom_env.yml @@ -36,6 +36,7 @@ body: ```python import gym import numpy as np + from gym import spaces from stable_baselines3 import A2C from stable_baselines3.common.env_checker import check_env @@ -43,20 +44,20 @@ body: class CustomEnv(gym.Env): - def __init__(self): - super().__init__() - self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) - self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,)) + def __init__(self): + super().__init__() + self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) + self.action_space = spaces.Box(low=-1, high=1, shape=(6,)) - def reset(self): - return self.observation_space.sample() + def reset(self): + return self.observation_space.sample() - def step(self, action): - obs = self.observation_space.sample() - reward = 1.0 - done = False - info = {} - return obs, reward, done, info + def step(self, action): + obs = self.observation_space.sample() + reward = 1.0 + done = False + info = {} + return obs, reward, done, info env = CustomEnv() check_env(env) @@ -86,9 +87,8 @@ body: * Versions of any other relevant libraries You can use 
`sb3.get_system_info()` to print relevant packages info: - ```python - import stable_baselines3 as sb3 - sb3.get_system_info() + ```sh + python -c 'import stable_baselines3 as sb3; sb3.get_system_info()' ``` - type: checkboxes id: terms diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 89b15b7..eb1d08f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,7 +38,8 @@ pip install -e .[docs,tests,extra] ## Codestyle -We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports. +We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports. +For the documentation, we use the default line length of 88 characters per line. **Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`. diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index a561b2d..d2878c3 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -27,14 +27,17 @@ That is to say, your environment must implement the following methods (and inher .. code-block:: python import gym + import numpy as np from gym import spaces + class CustomEnv(gym.Env): - """Custom Environment that follows gym interface""" + """Custom Environment that follows gym interface.""" + metadata = {"render.modes": ["human"]} def __init__(self, arg1, arg2, ...): - super(CustomEnv, self).__init__() + super().__init__() # Define action and observation space # They must be gym.spaces objects # Example when using discrete actions: @@ -46,12 +49,15 @@ That is to say, your environment must implement the following methods (and inher def step(self, action): ... return observation, reward, done, info + def reset(self): ... return observation # reward, done, info can't be included + def render(self, mode="human"): ... 
- def close (self): + + def close(self): ... diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 616fc49..458f60f 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -125,9 +125,9 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t .. code-block:: python - import gym import torch as th import torch.nn as nn + from gym import spaces from stable_baselines3 import PPO from stable_baselines3.common.torch_layers import BaseFeaturesExtractor @@ -140,7 +140,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t This corresponds to the number of unit for the last layer. """ - def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256): + def __init__(self, observation_space: spaces.Box, features_dim: int = 256): super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper @@ -199,7 +199,7 @@ downsampling and "vector" with a single linear layer. from stable_baselines3.common.torch_layers import BaseFeaturesExtractor class CustomCombinedExtractor(BaseFeaturesExtractor): - def __init__(self, observation_space: gym.spaces.Dict): + def __init__(self, observation_space: spaces.Dict): # We do not know features-dim here before going over all the items, # so put something dummy for now. 
PyTorch requires calling # nn.Module.__init__ before adding modules @@ -310,7 +310,7 @@ If your task requires even more granular control over the policy/value architect from typing import Callable, Dict, List, Optional, Tuple, Type, Union - import gym + from gym import spaces import torch as th from torch import nn @@ -367,8 +367,8 @@ If your task requires even more granular control over the policy/value architect class CustomActorCriticPolicy(ActorCriticPolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Callable[[float], float], net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b5b1701..7e3ecd7 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.7.0a10 (WIP) +Release 1.7.0a11 (WIP) -------------------------- .. 
note:: @@ -71,6 +71,8 @@ Others: - Upgraded GitHub CI/setup-python to v4 and checkout to v3 - Set tensors construction directly on the device (~8% speed boost on GPU) - Monkey-patched ``np.bool = bool`` so gym 0.21 is compatible with NumPy 1.24+ +- Standardized the use of ``from gym import spaces`` +- Modified ``get_system_info`` to avoid issue linked to copy-pasting on GitHub issue Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 132f314..9d5fa03 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -11,6 +11,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Un import gym import numpy as np import torch as th +from gym import spaces from stable_baselines3.common import utils from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback @@ -101,7 +102,7 @@ class BaseAlgorithm(ABC): seed: Optional[int] = None, use_sde: bool = False, sde_sample_freq: int = -1, - supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, + supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, ): if isinstance(policy, str): self.policy_class = self._get_policy_from_name(policy) @@ -117,8 +118,8 @@ class BaseAlgorithm(ABC): self._vec_normalize_env = unwrap_vec_normalize(env) self.verbose = verbose self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs - self.observation_space = None # type: Optional[gym.spaces.Space] - self.action_space = None # type: Optional[gym.spaces.Space] + self.observation_space = None # type: Optional[spaces.Space] + self.action_space = None # type: Optional[spaces.Space] self.n_envs = None self.num_timesteps = 0 # Used for updating schedules @@ -175,13 +176,13 @@ class BaseAlgorithm(ABC): ) # Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy - if policy in ["MlpPolicy", "CnnPolicy"] and 
isinstance(self.observation_space, gym.spaces.Dict): + if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, spaces.Dict): raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}") - if self.use_sde and not isinstance(self.action_space, gym.spaces.Box): + if self.use_sde and not isinstance(self.action_space, spaces.Box): raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.") - if isinstance(self.action_space, gym.spaces.Box): + if isinstance(self.action_space, spaces.Box): assert np.all( np.isfinite(np.array([self.action_space.low, self.action_space.high])) ), "Continuous action space must have a finite lower and upper bound" @@ -212,7 +213,7 @@ class BaseAlgorithm(ABC): if not is_vecenv_wrapped(env, VecTransposeImage): wrap_with_vectranspose = False - if isinstance(env.observation_space, gym.spaces.Dict): + if isinstance(env.observation_space, spaces.Dict): # If even one of the keys is a image-space in need of transpose, apply transpose # If the image spaces are not consistent (for instance one is channel first, # the other channel last), VecTransposeImage will throw an error diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index b78ef82..b1cd439 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union -import gym import numpy as np import torch as th from gym import spaces @@ -659,7 +658,7 @@ class TanhBijector: def make_proba_distribution( - action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None + action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None ) -> Distribution: """ Return an instance of Distribution for the correct type of action 
space diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index fcf952d..cb682ca 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -266,7 +266,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action def _check_spaces(env: gym.Env) -> None: """ - Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For + Check that the observation and action spaces are defined and inherit from spaces.Space. For envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check the observation space is gym.spaces.Dict """ diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index 8f6ccd2..ea11268 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -1,14 +1,16 @@ -from typing import Optional, Union +from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union +import gym import numpy as np -from gym import Env, Space -from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete +from gym import spaces from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +T = TypeVar("T", int, np.ndarray) -class IdentityEnv(Env): - def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100): + +class IdentityEnv(gym.Env, Generic[T]): + def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = None, ep_length: int = 100): """ Identity environment for testing purposes @@ -22,7 +24,7 @@ class IdentityEnv(Env): if space is None: if dim is None: dim = 1 - space = Discrete(dim) + space = spaces.Discrete(dim) else: assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed" @@ -32,13 +34,13 @@ class IdentityEnv(Env): self.num_resets = -1 # Becomes 0 after __init__ exits. 
self.reset() - def reset(self) -> GymObs: + def reset(self) -> T: self.current_step = 0 self.num_resets += 1 self._choose_next_state() return self.state - def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: + def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 @@ -48,14 +50,14 @@ class IdentityEnv(Env): def _choose_next_state(self) -> None: self.state = self.action_space.sample() - def _get_reward(self, action: Union[int, np.ndarray]) -> float: + def _get_reward(self, action: T) -> float: return 1.0 if np.all(self.state == action) else 0.0 def render(self, mode: str = "human") -> None: pass -class IdentityEnvBox(IdentityEnv): +class IdentityEnvBox(IdentityEnv[np.ndarray]): def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100): """ Identity environment for testing purposes @@ -65,7 +67,7 @@ class IdentityEnvBox(IdentityEnv): :param eps: the epsilon bound for correct value :param ep_length: the length of each episode in timesteps """ - space = Box(low=low, high=high, shape=(1,), dtype=np.float32) + space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32) super().__init__(ep_length=ep_length, space=space) self.eps = eps @@ -80,7 +82,7 @@ class IdentityEnvBox(IdentityEnv): return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 -class IdentityEnvMultiDiscrete(IdentityEnv): +class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]): def __init__(self, dim: int = 1, ep_length: int = 100): """ Identity environment for testing purposes @@ -88,11 +90,11 @@ class IdentityEnvMultiDiscrete(IdentityEnv): :param dim: the size of the dimensions you want to learn :param ep_length: the length of each episode in timesteps """ - space = MultiDiscrete([dim, dim]) + space = spaces.MultiDiscrete([dim, dim]) super().__init__(ep_length=ep_length, space=space) -class 
IdentityEnvMultiBinary(IdentityEnv): +class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]): def __init__(self, dim: int = 1, ep_length: int = 100): """ Identity environment for testing purposes @@ -100,11 +102,11 @@ class IdentityEnvMultiBinary(IdentityEnv): :param dim: the size of the dimensions you want to learn :param ep_length: the length of each episode in timesteps """ - space = MultiBinary(dim) + space = spaces.MultiBinary(dim) super().__init__(ep_length=ep_length, space=space) -class FakeImageEnv(Env): +class FakeImageEnv(gym.Env): """ Fake image environment for testing purposes, it mimics Atari games. @@ -128,11 +130,11 @@ class FakeImageEnv(Env): self.observation_shape = (screen_height, screen_width, n_channels) if channel_first: self.observation_shape = (n_channels, screen_height, screen_width) - self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8) + self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8) if discrete: - self.action_space = Discrete(action_dim) + self.action_space = spaces.Discrete(action_dim) else: - self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32) + self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32) self.ep_length = 10 self.current_step = 0 diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 2e5f13f..433591d 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -2,6 +2,7 @@ from typing import Dict, Union import gym import numpy as np +from gym import spaces from stable_baselines3.common.type_aliases import GymStepReturn @@ -53,14 +54,14 @@ class SimpleMultiObsEnv(gym.Env): self.random_start = random_start self.discrete_actions = discrete_actions if discrete_actions: - self.action_space = gym.spaces.Discrete(4) + self.action_space = spaces.Discrete(4) else: - 
self.action_space = gym.spaces.Box(0, 1, (4,)) + self.action_space = spaces.Box(0, 1, (4,)) - self.observation_space = gym.spaces.Dict( + self.observation_space = spaces.Dict( spaces={ - "vec": gym.spaces.Box(0, 1, (self.vector_size,), dtype=np.float64), - "img": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8), + "vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64), + "img": spaces.Box(0, 255, self.img_size, dtype=np.uint8), } ) self.count = 0 diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 5e018fc..48779ef 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -6,9 +6,9 @@ import warnings from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -import gym import numpy as np import torch as th +from gym import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer @@ -100,7 +100,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, sde_support: bool = True, - supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, + supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, ): super().__init__( @@ -173,7 +173,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): # Use DictReplayBuffer if needed if self.replay_buffer_class is None: - if isinstance(self.observation_space, gym.spaces.Dict): + if isinstance(self.observation_space, spaces.Dict): self.replay_buffer_class = DictReplayBuffer else: self.replay_buffer_class = ReplayBuffer @@ -395,7 +395,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): unscaled_action, _ = self.predict(self._last_obs, deterministic=False) # Rescale the action from [low, high] to [-1, 1] - if isinstance(self.action_space, gym.spaces.Box): + if isinstance(self.action_space, spaces.Box): 
scaled_action = self.policy.scale_action(unscaled_action) # Add noise to the action (improve exploration) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 35ad2b9..bc0dda4 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -2,9 +2,9 @@ import sys import time from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -import gym import numpy as np import torch as th +from gym import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer @@ -70,7 +70,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): seed: Optional[int] = None, device: Union[th.device, str] = "auto", _init_setup_model: bool = True, - supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, + supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None, ): super().__init__( @@ -103,7 +103,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): self._setup_lr_schedule() self.set_random_seed(self.seed) - buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer + buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer self.rollout_buffer = buffer_cls( self.n_steps, @@ -169,7 +169,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error - if isinstance(self.action_space, gym.spaces.Box): + if isinstance(self.action_space, spaces.Box): clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) new_obs, rewards, dones, infos = env.step(clipped_actions) @@ -184,7 +184,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): self._update_info_buffer(infos) n_steps += 1 - if isinstance(self.action_space, gym.spaces.Discrete): + if isinstance(self.action_space, spaces.Discrete): 
# Reshape in case of discrete action actions = actions.reshape(-1, 1) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 7331815..a752a7a 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -7,9 +7,9 @@ from abc import ABC, abstractmethod from functools import partial from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -import gym import numpy as np import torch as th +from gym import spaces from torch import nn from stable_baselines3.common.distributions import ( @@ -60,8 +60,8 @@ class BaseModel(nn.Module): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor: Optional[nn.Module] = None, @@ -344,7 +344,7 @@ class BasePolicy(BaseModel, ABC): # Convert to numpy, and reshape to the original action shape actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape) - if isinstance(self.action_space, gym.spaces.Box): + if isinstance(self.action_space, spaces.Box): if self.squash_output: # Rescale to proper domain when using squashing actions = self.unscale_action(actions) @@ -415,8 +415,8 @@ class ActorCriticPolicy(BasePolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, @@ -749,8 +749,8 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[List[Union[int, 
Dict[str, List[int]]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, @@ -822,8 +822,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy): def __init__( self, - observation_space: gym.spaces.Dict, - action_space: gym.spaces.Space, + observation_space: spaces.Dict, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, @@ -890,8 +890,8 @@ class ContinuousCritic(BaseModel): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, net_arch: List[int], features_extractor: nn.Module, features_dim: int, diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 105d39f..bc766e0 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -3,6 +3,7 @@ from typing import Dict, List, Tuple, Type, Union import gym import torch as th +from gym import spaces from torch import nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space @@ -63,7 +64,7 @@ class NatureCNN(BaseFeaturesExtractor): def __init__( self, - observation_space: gym.spaces.Box, + observation_space: spaces.Box, features_dim: int = 512, normalized_image: bool = False, ) -> None: @@ -267,7 +268,7 @@ class CombinedExtractor(BaseFeaturesExtractor): def __init__( self, - observation_space: gym.spaces.Dict, + observation_space: spaces.Dict, cnn_output_dim: int = 256, normalized_image: bool = False, ) -> None: diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 6aabb32..4dc284e 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -2,6 +2,7 @@ import glob import os import platform import random +import re from collections import deque from itertools import zip_longest from typing import Dict, Iterable, List, 
Optional, Tuple, Union @@ -9,6 +10,7 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union import gym import numpy as np import torch as th +from gym import spaces import stable_baselines3 as sb3 @@ -210,7 +212,7 @@ def configure_logger( return configure(save_path, format_strings=format_strings) -def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None: +def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None: """ Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if spaces match after loading the model with given env. @@ -228,7 +230,7 @@ def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, a raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}") -def is_vectorized_box_observation(observation: np.ndarray, observation_space: gym.spaces.Box) -> bool: +def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool: """ For box observation type, detects and validates the shape, then returns whether or not the observation is vectorized. @@ -249,7 +251,7 @@ def is_vectorized_box_observation(observation: np.ndarray, observation_space: gy ) -def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Discrete) -> bool: +def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: spaces.Discrete) -> bool: """ For discrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized. 
@@ -269,7 +271,7 @@ def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], obse ) -def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: gym.spaces.MultiDiscrete) -> bool: +def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool: """ For multidiscrete observation type, detects and validates the shape, then returns whether or not the observation is vectorized. @@ -290,7 +292,7 @@ def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation ) -def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: gym.spaces.MultiBinary) -> bool: +def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool: """ For multibinary observation type, detects and validates the shape, then returns whether or not the observation is vectorized. @@ -311,7 +313,7 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s ) -def is_vectorized_dict_observation(observation: np.ndarray, observation_space: gym.spaces.Dict) -> bool: +def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool: """ For dict observation type, detects and validates the shape, then returns whether or not the observation is vectorized. @@ -355,7 +357,7 @@ def is_vectorized_dict_observation(observation: np.ndarray, observation_space: g ) -def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Space) -> bool: +def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: spaces.Space) -> bool: """ For every observation type, detects and validates the shape, then returns whether or not the observation is vectorized. 
@@ -366,11 +368,11 @@ def is_vectorized_observation(observation: Union[int, np.ndarray], observation_s """ is_vec_obs_func_dict = { - gym.spaces.Box: is_vectorized_box_observation, - gym.spaces.Discrete: is_vectorized_discrete_observation, - gym.spaces.MultiDiscrete: is_vectorized_multidiscrete_observation, - gym.spaces.MultiBinary: is_vectorized_multibinary_observation, - gym.spaces.Dict: is_vectorized_dict_observation, + spaces.Box: is_vectorized_box_observation, + spaces.Discrete: is_vectorized_discrete_observation, + spaces.MultiDiscrete: is_vectorized_multidiscrete_observation, + spaces.MultiBinary: is_vectorized_multibinary_observation, + spaces.Dict: is_vectorized_dict_observation, } for space_type, is_vec_obs_func in is_vec_obs_func_dict.items(): @@ -505,7 +507,9 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: and a formatted string. """ env_info = { - "OS": f"{platform.platform()} {platform.version()}", + # In OS, a regex is used to add a space between a "#" and a number to avoid + # wrongly linking to another issue on GitHub. Example: turn "#42" to "# 42". 
+ "OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"), "Python": platform.python_version(), "Stable-Baselines3": sb3.__version__, "PyTorch": th.__version__, @@ -515,7 +519,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: } env_info_str = "" for key, value in env_info.items(): - env_info_str += f"{key}: {value}\n" + env_info_str += f"- {key}: {value}\n" if print_info: print(env_info_str) return env_info, env_info_str diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 9870605..0b3e1b4 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, U import cloudpickle import gym import numpy as np +from gym import spaces # Define type aliases here to avoid circular import # Used when we want to access one or more VecEnv @@ -48,14 +49,14 @@ class VecEnv(ABC): """ An abstract asynchronous, vectorized environment. 
- :param num_envs: the number of environments - :param observation_space: the observation space - :param action_space: the action space + :param num_envs: Number of environments + :param observation_space: Observation space + :param action_space: Action space """ metadata = {"render.modes": ["human", "rgb_array"]} - def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space): + def __init__(self, num_envs: int, observation_space: spaces.Space, action_space: spaces.Space): self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space @@ -248,8 +249,8 @@ class VecEnvWrapper(VecEnv): def __init__( self, venv: VecEnv, - observation_space: Optional[gym.spaces.Space] = None, - action_space: Optional[gym.spaces.Space] = None, + observation_space: Optional[spaces.Space] = None, + action_space: Optional[spaces.Space] = None, ): self.venv = venv VecEnv.__init__( diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index f723c71..7ff579d 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -4,6 +4,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union import gym import numpy as np +from gym import spaces from stable_baselines3.common.vec_env.base_vec_env import ( CloudpickleWrapper, @@ -196,7 +197,7 @@ class SubprocVecEnv(VecEnv): return [self.remotes[i] for i in indices] -def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.spaces.Space) -> VecEnvObs: +def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: """ Flatten observations, depending on the observation space. 
@@ -210,11 +211,11 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" assert len(obs) > 0, "need observations from at least one environment" - if isinstance(space, gym.spaces.Dict): + if isinstance(space, spaces.Dict): assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) - elif isinstance(space, gym.spaces.Tuple): + elif isinstance(space, spaces.Tuple): assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index ca590cb..7d318ac 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -4,8 +4,8 @@ Helpers for dealing with vectorized environments. from collections import OrderedDict from typing import Any, Dict, List, Tuple -import gym import numpy as np +from gym import spaces from stable_baselines3.common.preprocessing import check_for_nested_spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs @@ -22,7 +22,7 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) -def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: +def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: """ Convert an internal representation raw_obs into the appropriate type specified by space. 
@@ -33,9 +33,9 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; otherwise, space is unstructured and returns the value raw_obs[None]. """ - if isinstance(obs_space, gym.spaces.Dict): + if isinstance(obs_space, spaces.Dict): return obs_dict - elif isinstance(obs_space, gym.spaces.Tuple): + elif isinstance(obs_space, spaces.Tuple): assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) else: @@ -43,7 +43,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> return obs_dict[None] -def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: +def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: """ Get dict-structured information about a gym.Space. @@ -58,10 +58,10 @@ def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tu dtypes: a dict mapping keys to dtypes. 
""" check_for_nested_spaces(obs_space) - if isinstance(obs_space, gym.spaces.Dict): + if isinstance(obs_space, spaces.Dict): assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces - elif isinstance(obs_space, gym.spaces.Tuple): + elif isinstance(obs_space, spaces.Tuple): subspaces = {i: space for i, space in enumerate(obs_space.spaces)} else: assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index ad400d1..8b98413 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -2,8 +2,8 @@ import pickle from copy import deepcopy from typing import Any, Dict, List, Optional, Union -import gym import numpy as np +from gym import spaces from stable_baselines3.common import utils from stable_baselines3.common.preprocessing import is_image_space @@ -48,14 +48,14 @@ class VecNormalize(VecEnvWrapper): if self.norm_obs: self._sanity_checks() - if isinstance(self.observation_space, gym.spaces.Dict): + if isinstance(self.observation_space, spaces.Dict): self.obs_spaces = self.observation_space.spaces self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} # Update observation space when using image # See explanation below and GH #1214 for key in self.obs_rms.keys(): if is_image_space(self.obs_spaces[key]): - self.observation_space.spaces[key] = gym.spaces.Box( + self.observation_space.spaces[key] = spaces.Box( low=-clip_obs, high=clip_obs, shape=self.obs_spaces[key].shape, @@ -74,7 +74,7 @@ class VecNormalize(VecEnvWrapper): # in other cases but this will cause backward-incompatible change # and break already saved policies. 
if is_image_space(self.observation_space): - self.observation_space = gym.spaces.Box( + self.observation_space = spaces.Box( low=-clip_obs, high=clip_obs, shape=self.observation_space.shape, @@ -98,13 +98,13 @@ class VecNormalize(VecEnvWrapper): """ Check the observations that are going to be normalized are of the correct type (spaces.Box). """ - if isinstance(self.observation_space, gym.spaces.Dict): + if isinstance(self.observation_space, spaces.Dict): # By default, we normalize all keys if self.norm_obs_keys is None: self.norm_obs_keys = list(self.observation_space.spaces.keys()) # Check that all keys are of type Box for obs_key in self.norm_obs_keys: - if not isinstance(self.observation_space.spaces[obs_key], gym.spaces.Box): + if not isinstance(self.observation_space.spaces[obs_key], spaces.Box): raise ValueError( f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} " f"is of type {self.observation_space.spaces[obs_key]}. " @@ -112,7 +112,7 @@ class VecNormalize(VecEnvWrapper): " that should be normalized via the `norm_obs_keys` parameter." 
) - elif isinstance(self.observation_space, gym.spaces.Box): + elif isinstance(self.observation_space, spaces.Box): if self.norm_obs_keys is not None: raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces") @@ -143,7 +143,7 @@ class VecNormalize(VecEnvWrapper): :param state:""" # Backward compatibility - if "norm_obs_keys" not in state and isinstance(state["observation_space"], gym.spaces.Dict): + if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict): state["norm_obs_keys"] = list(state["observation_space"].spaces.keys()) self.__dict__.update(state) assert "venv" not in state diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index e8f4945..dd8794e 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -1,9 +1,9 @@ import warnings from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -import gym import numpy as np import torch as th +from gym import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer @@ -116,7 +116,7 @@ class DQN(OffPolicyAlgorithm): seed=seed, sde_support=False, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(gym.spaces.Discrete,), + supported_action_spaces=(spaces.Discrete,), support_multi_env=True, ) diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 1686ec0..22e6d0a 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Type -import gym import torch as th +from gym import spaces from torch import nn from stable_baselines3.common.policies import BasePolicy @@ -29,8 +29,8 @@ class QNetwork(BasePolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, features_extractor: nn.Module, 
features_dim: int, net_arch: Optional[List[int]] = None, @@ -106,8 +106,8 @@ class DQNPolicy(BasePolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -227,8 +227,8 @@ class CnnPolicy(DQNPolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -272,8 +272,8 @@ class MultiInputPolicy(DQNPolicy): def __init__( self, - observation_space: gym.spaces.Dict, - action_space: gym.spaces.Space, + observation_space: spaces.Dict, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[List[int]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index d398a70..e756097 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union -import gym import torch as th +from gym import spaces from torch import nn from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution @@ -47,8 +47,8 @@ class Actor(BasePolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -207,8 +207,8 @@ class SACPolicy(BasePolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[Union[List[int], 
Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -385,8 +385,8 @@ class CnnPolicy(SACPolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -451,8 +451,8 @@ class MultiInputPolicy(SACPolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index fdd23c2..74285b6 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -1,8 +1,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -import gym import numpy as np import torch as th +from gym import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer @@ -133,7 +133,7 @@ class SAC(OffPolicyAlgorithm): sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(gym.spaces.Box), + supported_action_spaces=(spaces.Box), support_multi_env=True, ) diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 676adda..6c4a1e9 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Type, Union -import gym import torch as th +from gym import spaces from torch import nn from stable_baselines3.common.policies import BasePolicy, ContinuousCritic @@ -34,8 +34,8 @@ class Actor(BasePolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: 
gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, net_arch: List[int], features_extractor: nn.Module, features_dim: int, @@ -108,8 +108,8 @@ class TD3Policy(BasePolicy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -271,8 +271,8 @@ class CnnPolicy(TD3Policy): def __init__( self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, + observation_space: spaces.Space, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, @@ -325,8 +325,8 @@ class MultiInputPolicy(TD3Policy): def __init__( self, - observation_space: gym.spaces.Dict, - action_space: gym.spaces.Space, + observation_space: spaces.Dict, + action_space: spaces.Space, lr_schedule: Schedule, net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.ReLU, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 97812a9..ae442e1 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -1,8 +1,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union -import gym import numpy as np import torch as th +from gym import spaces from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer @@ -116,7 +116,7 @@ class TD3(OffPolicyAlgorithm): seed=seed, sde_support=False, optimize_memory_usage=optimize_memory_usage, - supported_action_spaces=(gym.spaces.Box), + supported_action_spaces=(spaces.Box), support_multi_env=True, ) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 89e17c2..a02b7e4 100644 --- a/stable_baselines3/version.txt +++ 
b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a10 +1.7.0a11 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 0b0a82d..3159786 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -1,14 +1,14 @@ import gym import numpy as np import pytest -from gym.spaces import Box, Dict, Discrete +from gym import spaces from stable_baselines3.common.env_checker import check_env class ActionDictTestEnv(gym.Env): - action_space = Dict({"position": Discrete(1), "velocity": Discrete(1)}) - observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32) + action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)}) + observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32) def step(self, action): observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) diff --git a/tests/test_gae.py b/tests/test_gae.py index 8e461ed..c90470f 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -2,6 +2,7 @@ import gym import numpy as np import pytest import torch as th +from gym import spaces from stable_baselines3 import A2C, PPO, SAC from stable_baselines3.common.callbacks import BaseCallback @@ -11,8 +12,8 @@ from stable_baselines3.common.policies import ActorCriticPolicy class CustomEnv(gym.Env): def __init__(self, max_steps=8): super().__init__() - self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.max_steps = max_steps self.n_steps = 0 @@ -39,8 +40,8 @@ class InfiniteHorizonEnv(gym.Env): def __init__(self, n_states=4): super().__init__() self.n_states = n_states - self.observation_space = gym.spaces.Discrete(n_states) - self.action_space = gym.spaces.Box(low=-1, high=1, 
shape=(2,), dtype=np.float32) + self.observation_space = spaces.Discrete(n_states) + self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 def reset(self): diff --git a/tests/test_logger.py b/tests/test_logger.py index 92b65e8..94f1e21 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -8,6 +8,7 @@ import gym import numpy as np import pytest import torch as th +from gym import spaces from matplotlib import pyplot as plt from pandas.errors import EmptyDataError @@ -350,8 +351,8 @@ class TimeDelayEnv(gym.Env): def __init__(self, delay: float = 0.01): super().__init__() self.delay = delay - self.observation_space = gym.spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32) - self.action_space = gym.spaces.Discrete(2) + self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32) + self.action_space = spaces.Discrete(2) def reset(self): return self.observation_space.sample() diff --git a/tests/test_predict.py b/tests/test_predict.py index 6343e2f..579abff 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -2,6 +2,7 @@ import gym import numpy as np import pytest import torch as th +from gym import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.envs import IdentityEnv @@ -17,7 +18,7 @@ MODEL_LIST = [ ] -class SubClassedBox(gym.spaces.Box): +class SubClassedBox(spaces.Box): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 6f530b7..6dd6dc4 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,6 +1,7 @@ import gym import numpy as np import pytest +from gym import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_util import make_vec_env @@ -10,8 +11,8 @@ from stable_baselines3.common.evaluation import evaluate_policy class DummyMultiDiscreteSpace(gym.Env): def __init__(self, 
nvec): super().__init__() - self.observation_space = gym.spaces.MultiDiscrete(nvec) - self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + self.observation_space = spaces.MultiDiscrete(nvec) + self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) def reset(self): return self.observation_space.sample() @@ -23,8 +24,8 @@ class DummyMultiDiscreteSpace(gym.Env): class DummyMultiBinary(gym.Env): def __init__(self, n): super().__init__() - self.observation_space = gym.spaces.MultiBinary(n) - self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + self.observation_space = spaces.MultiBinary(n) + self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) def reset(self): return self.observation_space.sample() @@ -36,8 +37,8 @@ class DummyMultiBinary(gym.Env): class DummyMultidimensionalAction(gym.Env): def __init__(self): super().__init__() - self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) + self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) def reset(self): return self.observation_space.sample() @@ -55,7 +56,7 @@ def test_identity_spaces(model_class, env): """ # DQN only support discrete actions if model_class == DQN: - env.action_space = gym.spaces.Discrete(4) + env.action_space = spaces.Discrete(4) env = gym.wrappers.TimeLimit(env, max_episode_steps=100) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 7a09581..e661e1a 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -6,6 +6,7 @@ import multiprocessing import gym import numpy as np import pytest +from gym import spaces from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, 
VecNormalize @@ -64,7 +65,7 @@ class CustomGymEnv(gym.Env): def test_vecenv_func_checker(): """The functions in ``env_fns'' must return distinct instances since we need distinct environments.""" - env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + env = CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))) with pytest.raises(ValueError): DummyVecEnv([lambda: env for _ in range(N_ENVS)]) @@ -78,7 +79,7 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper): """Test access to methods/attributes of vectorized environments""" def make_env(): - return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))) vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) @@ -147,8 +148,8 @@ class StepEnv(gym.Env): def __init__(self, max_steps): """Gym environment for testing that terminal observation is inserted correctly.""" - self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]), dtype="int") + self.action_space = spaces.Discrete(2) + self.observation_space = spaces.Box(np.array([0]), np.array([999]), dtype="int") self.max_steps = max_steps self.current_step = 0 @@ -210,10 +211,10 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper): SPACES = collections.OrderedDict( [ - ("discrete", gym.spaces.Discrete(2)), - ("multidiscrete", gym.spaces.MultiDiscrete([2, 3])), - ("multibinary", gym.spaces.MultiBinary(3)), - ("continuous", gym.spaces.Box(low=np.zeros(2), high=np.ones(2))), + ("discrete", spaces.Discrete(2)), + ("multidiscrete", spaces.MultiDiscrete([2, 3])), + ("multibinary", spaces.MultiBinary(3)), + ("continuous", spaces.Box(low=np.zeros(2), high=np.ones(2))), ] ) @@ -252,7 +253,7 @@ def test_vecenv_single_space(vec_env_class, space): check_vecenv_spaces(vec_env_class, space, obs_assert) -class _UnorderedDictSpace(gym.spaces.Dict): +class _UnorderedDictSpace(spaces.Dict): """Like DictSpace, 
but returns an unordered dict when sampling.""" def sample(self): @@ -262,7 +263,7 @@ class _UnorderedDictSpace(gym.spaces.Dict): @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) def test_vecenv_dict_spaces(vec_env_class): """Test dictionary observation spaces with vectorized environments.""" - space = gym.spaces.Dict(SPACES) + space = spaces.Dict(SPACES) def obs_assert(obs): assert isinstance(obs, collections.OrderedDict) @@ -280,7 +281,7 @@ def test_vecenv_dict_spaces(vec_env_class): @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) def test_vecenv_tuple_spaces(vec_env_class): """Test tuple observation spaces with vectorized environments.""" - space = gym.spaces.Tuple(tuple(SPACES.values())) + space = spaces.Tuple(tuple(SPACES.values())) def obs_assert(obs): assert isinstance(obs, tuple) @@ -298,7 +299,7 @@ def test_subproc_start_method(): all_methods = {"forkserver", "spawn", "fork"} available_methods = multiprocessing.get_all_start_methods() start_methods += list(all_methods.intersection(available_methods)) - space = gym.spaces.Discrete(2) + space = spaces.Discrete(2) def obs_assert(obs): return check_vecenv_obs(obs, space) @@ -338,7 +339,7 @@ class CustomWrapperBB(CustomWrapperB): def test_vecenv_wrapper_getattr(): def make_env(): - return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))) vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)]) wrapped = CustomWrapperA(CustomWrapperBB(vec_env)) @@ -367,7 +368,7 @@ def test_framestack_vecenv(): def make_image_env(): return CustomGymEnv( - gym.spaces.Box( + spaces.Box( low=np.zeros(image_space_shape), high=np.ones(image_space_shape) * 255, dtype=np.uint8, @@ -376,7 +377,7 @@ def test_framestack_vecenv(): def make_transposed_image_env(): return CustomGymEnv( - gym.spaces.Box( + spaces.Box( low=np.zeros(transposed_image_space_shape), high=np.ones(transposed_image_space_shape) * 255, dtype=np.uint8, @@ -384,7 +385,7 
@@ def test_framestack_vecenv(): ) def make_non_image_env(): - return CustomGymEnv(gym.spaces.Box(low=np.zeros((2,)), high=np.ones((2,)))) + return CustomGymEnv(spaces.Box(low=np.zeros((2,)), high=np.ones((2,)))) vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) vec_env = VecFrameStack(vec_env, n_stack=2) @@ -433,10 +434,10 @@ def test_framestack_vecenv(): def test_vec_env_is_wrapped(): # Test is_wrapped call of subproc workers def make_env(): - return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))) def make_monitored_env(): - return Monitor(CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))) + return Monitor(CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))) # One with monitor, one without vec_env = SubprocVecEnv([make_env, make_monitored_env]) @@ -457,7 +458,7 @@ def test_vec_env_is_wrapped(): @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) def test_vec_seeding(vec_env_class): def make_env(): - return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))) # For SubprocVecEnv check for all starting methods start_methods = [None] diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 00af193..7b443c2 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -23,8 +23,8 @@ class DummyRewardEnv(gym.Env): metadata = {} def __init__(self, return_reward_idx=0): - self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Box(low=np.array([-1.0]), high=np.array([1.0])) + self.action_space = spaces.Discrete(2) + self.observation_space = spaces.Box(low=np.array([-1.0]), high=np.array([1.0])) self.returned_rewards = [0, 1, 3, 4] self.return_reward_idx = return_reward_idx self.t = self.return_reward_idx