Standardize the use of from gym import spaces (#1240)

* generalize the use of `from gym import spaces`

* command line get system info

* Documentation line length for doc

* update changelog

* add space before os platform to avoid ref to other issue

* format

* get_system_info update in changelog

* fix type check error

* fix get system info

* add comment about regex

* update version
This commit is contained in:
Quentin Gallouédec 2023-01-02 14:51:11 +01:00 committed by GitHub
parent 2bb8ef5e63
commit 4fa17dcf0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 219 additions and 196 deletions

View file

@ -52,9 +52,8 @@ body:
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```python
import stable_baselines3 as sb3
sb3.get_system_info()
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms

View file

@ -36,6 +36,7 @@ body:
```python
import gym
import numpy as np
from gym import spaces
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
@ -43,20 +44,20 @@ body:
class CustomEnv(gym.Env):
def __init__(self):
super().__init__()
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))
def __init__(self):
super().__init__()
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
def reset(self):
return self.observation_space.sample()
def reset(self):
return self.observation_space.sample()
def step(self, action):
obs = self.observation_space.sample()
reward = 1.0
done = False
info = {}
return obs, reward, done, info
def step(self, action):
obs = self.observation_space.sample()
reward = 1.0
done = False
info = {}
return obs, reward, done, info
env = CustomEnv()
check_env(env)
@ -86,9 +87,8 @@ body:
* Versions of any other relevant libraries
You can use `sb3.get_system_info()` to print relevant packages info:
```python
import stable_baselines3 as sb3
sb3.get_system_info()
```sh
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
```
- type: checkboxes
id: terms

View file

@ -38,7 +38,8 @@ pip install -e .[docs,tests,extra]
## Codestyle
We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
For the documentation, we use the default line length of 88 characters per line.
**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.

View file

@ -27,14 +27,17 @@ That is to say, your environment must implement the following methods (and inher
.. code-block:: python
import gym
import numpy as np
from gym import spaces
class CustomEnv(gym.Env):
"""Custom Environment that follows gym interface"""
"""Custom Environment that follows gym interface."""
metadata = {"render.modes": ["human"]}
def __init__(self, arg1, arg2, ...):
super(CustomEnv, self).__init__()
super().__init__()
# Define action and observation space
# They must be gym.spaces objects
# Example when using discrete actions:
@ -46,12 +49,15 @@ That is to say, your environment must implement the following methods (and inher
def step(self, action):
...
return observation, reward, done, info
def reset(self):
...
return observation # reward, done, info can't be included
def render(self, mode="human"):
...
def close (self):
def close(self):
...

View file

@ -125,9 +125,9 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
.. code-block:: python
import gym
import torch as th
import torch.nn as nn
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
@ -140,7 +140,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
This corresponds to the number of units for the last layer.
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
@ -199,7 +199,7 @@ downsampling and "vector" with a single linear layer.
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCombinedExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.spaces.Dict):
def __init__(self, observation_space: spaces.Dict):
# We do not know features-dim here before going over all the items,
# so put something dummy for now. PyTorch requires calling
# nn.Module.__init__ before adding modules
@ -310,7 +310,7 @@ If your task requires even more granular control over the policy/value architect
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import gym
from gym import spaces
import torch as th
from torch import nn
@ -367,8 +367,8 @@ If your task requires even more granular control over the policy/value architect
class CustomActorCriticPolicy(ActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Callable[[float], float],
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.7.0a10 (WIP)
Release 1.7.0a11 (WIP)
--------------------------
.. note::
@ -71,6 +71,8 @@ Others:
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
- Set tensors construction directly on the device (~8% speed boost on GPU)
- Monkey-patched ``np.bool = bool`` so gym 0.21 is compatible with NumPy 1.24+
- Standardized the use of ``from gym import spaces``
- Modified ``get_system_info`` to avoid issues linked to copy-pasting its output into a GitHub issue
Documentation:
^^^^^^^^^^^^^^

View file

@ -11,6 +11,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Un
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
@ -101,7 +102,7 @@ class BaseAlgorithm(ABC):
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
@ -117,8 +118,8 @@ class BaseAlgorithm(ABC):
self._vec_normalize_env = unwrap_vec_normalize(env)
self.verbose = verbose
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
self.observation_space = None # type: Optional[gym.spaces.Space]
self.action_space = None # type: Optional[gym.spaces.Space]
self.observation_space = None # type: Optional[spaces.Space]
self.action_space = None # type: Optional[spaces.Space]
self.n_envs = None
self.num_timesteps = 0
# Used for updating schedules
@ -175,13 +176,13 @@ class BaseAlgorithm(ABC):
)
# Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, spaces.Dict):
raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
if self.use_sde and not isinstance(self.action_space, spaces.Box):
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
assert np.all(
np.isfinite(np.array([self.action_space.low, self.action_space.high]))
), "Continuous action space must have a finite lower and upper bound"
@ -212,7 +213,7 @@ class BaseAlgorithm(ABC):
if not is_vecenv_wrapped(env, VecTransposeImage):
wrap_with_vectranspose = False
if isinstance(env.observation_space, gym.spaces.Dict):
if isinstance(env.observation_space, spaces.Dict):
# If even one of the keys is an image-space in need of transpose, apply transpose
# If the image spaces are not consistent (for instance one is channel first,
# the other channel last), VecTransposeImage will throw an error

View file

@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
@ -659,7 +658,7 @@ class TanhBijector:
def make_proba_distribution(
action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space

View file

@ -266,7 +266,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
def _check_spaces(env: gym.Env) -> None:
"""
Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For
Check that the observation and action spaces are defined and inherit from spaces.Space. For
envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check
the observation space is gym.spaces.Dict
"""

View file

@ -1,14 +1,16 @@
from typing import Optional, Union
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
import gym
import numpy as np
from gym import Env, Space
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
from gym import spaces
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
T = TypeVar("T", int, np.ndarray)
class IdentityEnv(Env):
def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
class IdentityEnv(gym.Env, Generic[T]):
def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = None, ep_length: int = 100):
"""
Identity environment for testing purposes
@ -22,7 +24,7 @@ class IdentityEnv(Env):
if space is None:
if dim is None:
dim = 1
space = Discrete(dim)
space = spaces.Discrete(dim)
else:
assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
@ -32,13 +34,13 @@ class IdentityEnv(Env):
self.num_resets = -1 # Becomes 0 after __init__ exits.
self.reset()
def reset(self) -> GymObs:
def reset(self) -> T:
self.current_step = 0
self.num_resets += 1
self._choose_next_state()
return self.state
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
@ -48,14 +50,14 @@ class IdentityEnv(Env):
def _choose_next_state(self) -> None:
self.state = self.action_space.sample()
def _get_reward(self, action: Union[int, np.ndarray]) -> float:
def _get_reward(self, action: T) -> float:
return 1.0 if np.all(self.state == action) else 0.0
def render(self, mode: str = "human") -> None:
pass
class IdentityEnvBox(IdentityEnv):
class IdentityEnvBox(IdentityEnv[np.ndarray]):
def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
"""
Identity environment for testing purposes
@ -65,7 +67,7 @@ class IdentityEnvBox(IdentityEnv):
:param eps: the epsilon bound for correct value
:param ep_length: the length of each episode in timesteps
"""
space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
super().__init__(ep_length=ep_length, space=space)
self.eps = eps
@ -80,7 +82,7 @@ class IdentityEnvBox(IdentityEnv):
return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
class IdentityEnvMultiDiscrete(IdentityEnv):
class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
def __init__(self, dim: int = 1, ep_length: int = 100):
"""
Identity environment for testing purposes
@ -88,11 +90,11 @@ class IdentityEnvMultiDiscrete(IdentityEnv):
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = MultiDiscrete([dim, dim])
space = spaces.MultiDiscrete([dim, dim])
super().__init__(ep_length=ep_length, space=space)
class IdentityEnvMultiBinary(IdentityEnv):
class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
def __init__(self, dim: int = 1, ep_length: int = 100):
"""
Identity environment for testing purposes
@ -100,11 +102,11 @@ class IdentityEnvMultiBinary(IdentityEnv):
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = MultiBinary(dim)
space = spaces.MultiBinary(dim)
super().__init__(ep_length=ep_length, space=space)
class FakeImageEnv(Env):
class FakeImageEnv(gym.Env):
"""
Fake image environment for testing purposes, it mimics Atari games.
@ -128,11 +130,11 @@ class FakeImageEnv(Env):
self.observation_shape = (screen_height, screen_width, n_channels)
if channel_first:
self.observation_shape = (n_channels, screen_height, screen_width)
self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
if discrete:
self.action_space = Discrete(action_dim)
self.action_space = spaces.Discrete(action_dim)
else:
self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32)
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
self.ep_length = 10
self.current_step = 0

View file

@ -2,6 +2,7 @@ from typing import Dict, Union
import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.type_aliases import GymStepReturn
@ -53,14 +54,14 @@ class SimpleMultiObsEnv(gym.Env):
self.random_start = random_start
self.discrete_actions = discrete_actions
if discrete_actions:
self.action_space = gym.spaces.Discrete(4)
self.action_space = spaces.Discrete(4)
else:
self.action_space = gym.spaces.Box(0, 1, (4,))
self.action_space = spaces.Box(0, 1, (4,))
self.observation_space = gym.spaces.Dict(
self.observation_space = spaces.Dict(
spaces={
"vec": gym.spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
"img": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8),
"vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
"img": spaces.Box(0, 255, self.img_size, dtype=np.uint8),
}
)
self.count = 0

View file

@ -6,9 +6,9 @@ import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
@ -100,7 +100,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
super().__init__(
@ -173,7 +173,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
# Use DictReplayBuffer if needed
if self.replay_buffer_class is None:
if isinstance(self.observation_space, gym.spaces.Dict):
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
else:
self.replay_buffer_class = ReplayBuffer
@ -395,7 +395,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
# Rescale the action from [low, high] to [-1, 1]
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
scaled_action = self.policy.scale_action(unscaled_action)
# Add noise to the action (improve exploration)

View file

@ -2,9 +2,9 @@ import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
@ -70,7 +70,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
super().__init__(
@ -103,7 +103,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
self._setup_lr_schedule()
self.set_random_seed(self.seed)
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
self.rollout_buffer = buffer_cls(
self.n_steps,
@ -169,7 +169,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
# Rescale and perform action
clipped_actions = actions
# Clip the actions to avoid out of bound error
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions)
@ -184,7 +184,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
self._update_info_buffer(infos)
n_steps += 1
if isinstance(self.action_space, gym.spaces.Discrete):
if isinstance(self.action_space, spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)

View file

@ -7,9 +7,9 @@ from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from torch import nn
from stable_baselines3.common.distributions import (
@ -60,8 +60,8 @@ class BaseModel(nn.Module):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor: Optional[nn.Module] = None,
@ -344,7 +344,7 @@ class BasePolicy(BaseModel, ABC):
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
if isinstance(self.action_space, gym.spaces.Box):
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
@ -415,8 +415,8 @@ class ActorCriticPolicy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
@ -749,8 +749,8 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
@ -822,8 +822,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
def __init__(
self,
observation_space: gym.spaces.Dict,
action_space: gym.spaces.Space,
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
@ -890,8 +890,8 @@ class ContinuousCritic(BaseModel):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,

View file

@ -3,6 +3,7 @@ from typing import Dict, List, Tuple, Type, Union
import gym
import torch as th
from gym import spaces
from torch import nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
@ -63,7 +64,7 @@ class NatureCNN(BaseFeaturesExtractor):
def __init__(
self,
observation_space: gym.spaces.Box,
observation_space: spaces.Box,
features_dim: int = 512,
normalized_image: bool = False,
) -> None:
@ -267,7 +268,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
def __init__(
self,
observation_space: gym.spaces.Dict,
observation_space: spaces.Dict,
cnn_output_dim: int = 256,
normalized_image: bool = False,
) -> None:

View file

@ -2,6 +2,7 @@ import glob
import os
import platform
import random
import re
from collections import deque
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union
@ -9,6 +10,7 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union
import gym
import numpy as np
import torch as th
from gym import spaces
import stable_baselines3 as sb3
@ -210,7 +212,7 @@ def configure_logger(
return configure(save_path, format_strings=format_strings)
def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None:
def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
"""
Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if
spaces match after loading the model with given env.
@ -228,7 +230,7 @@ def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, a
raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
def is_vectorized_box_observation(observation: np.ndarray, observation_space: gym.spaces.Box) -> bool:
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
"""
For box observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
@ -249,7 +251,7 @@ def is_vectorized_box_observation(observation: np.ndarray, observation_space: gy
)
def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Discrete) -> bool:
def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: spaces.Discrete) -> bool:
"""
For discrete observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
@ -269,7 +271,7 @@ def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], obse
)
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: gym.spaces.MultiDiscrete) -> bool:
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool:
"""
For multidiscrete observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
@ -290,7 +292,7 @@ def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation
)
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: gym.spaces.MultiBinary) -> bool:
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
"""
For multibinary observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
@ -311,7 +313,7 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s
)
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: gym.spaces.Dict) -> bool:
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool:
"""
For dict observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
@ -355,7 +357,7 @@ def is_vectorized_dict_observation(observation: np.ndarray, observation_space: g
)
def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Space) -> bool:
def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: spaces.Space) -> bool:
"""
For every observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
@ -366,11 +368,11 @@ def is_vectorized_observation(observation: Union[int, np.ndarray], observation_s
"""
is_vec_obs_func_dict = {
gym.spaces.Box: is_vectorized_box_observation,
gym.spaces.Discrete: is_vectorized_discrete_observation,
gym.spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
gym.spaces.MultiBinary: is_vectorized_multibinary_observation,
gym.spaces.Dict: is_vectorized_dict_observation,
spaces.Box: is_vectorized_box_observation,
spaces.Discrete: is_vectorized_discrete_observation,
spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
spaces.MultiBinary: is_vectorized_multibinary_observation,
spaces.Dict: is_vectorized_dict_observation,
}
for space_type, is_vec_obs_func in is_vec_obs_func_dict.items():
@ -505,7 +507,9 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
and a formatted string.
"""
env_info = {
"OS": f"{platform.platform()} {platform.version()}",
# In OS, a regex is used to add a space between a "#" and a number to avoid
# wrongly linking to another issue on GitHub. Example: turn "#42" to "# 42".
"OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"),
"Python": platform.python_version(),
"Stable-Baselines3": sb3.__version__,
"PyTorch": th.__version__,
@ -515,7 +519,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
}
env_info_str = ""
for key, value in env_info.items():
env_info_str += f"{key}: {value}\n"
env_info_str += f"- {key}: {value}\n"
if print_info:
print(env_info_str)
return env_info, env_info_str

View file

@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, U
import cloudpickle
import gym
import numpy as np
from gym import spaces
# Define type aliases here to avoid circular import
# Used when we want to access one or more VecEnv
@ -48,14 +49,14 @@ class VecEnv(ABC):
"""
An abstract asynchronous, vectorized environment.
:param num_envs: the number of environments
:param observation_space: the observation space
:param action_space: the action space
:param num_envs: Number of environments
:param observation_space: Observation space
:param action_space: Action space
"""
metadata = {"render.modes": ["human", "rgb_array"]}
def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space):
def __init__(self, num_envs: int, observation_space: spaces.Space, action_space: spaces.Space):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
@ -248,8 +249,8 @@ class VecEnvWrapper(VecEnv):
def __init__(
self,
venv: VecEnv,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
observation_space: Optional[spaces.Space] = None,
action_space: Optional[spaces.Space] = None,
):
self.venv = venv
VecEnv.__init__(

View file

@ -4,6 +4,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.vec_env.base_vec_env import (
CloudpickleWrapper,
@ -196,7 +197,7 @@ class SubprocVecEnv(VecEnv):
return [self.remotes[i] for i in indices]
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.spaces.Space) -> VecEnvObs:
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
"""
Flatten observations, depending on the observation space.
@ -210,11 +211,11 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs) > 0, "need observations from at least one environment"
if isinstance(space, gym.spaces.Dict):
if isinstance(space, spaces.Dict):
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
elif isinstance(space, gym.spaces.Tuple):
elif isinstance(space, spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len))

View file

@ -4,8 +4,8 @@ Helpers for dealing with vectorized environments.
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.preprocessing import check_for_nested_spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
@ -22,7 +22,7 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
"""
Convert an internal representation raw_obs into the appropriate type
specified by space.
@ -33,9 +33,9 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) ->
If space is Dict, function is identity; if space is Tuple, converts dict to Tuple;
otherwise, space is unstructured and returns the value raw_obs[None].
"""
if isinstance(obs_space, gym.spaces.Dict):
if isinstance(obs_space, spaces.Dict):
return obs_dict
elif isinstance(obs_space, gym.spaces.Tuple):
elif isinstance(obs_space, spaces.Tuple):
assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
else:
@ -43,7 +43,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) ->
return obs_dict[None]
def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
"""
Get dict-structured information about a gym.Space.
@ -58,10 +58,10 @@ def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tu
dtypes: a dict mapping keys to dtypes.
"""
check_for_nested_spaces(obs_space)
if isinstance(obs_space, gym.spaces.Dict):
if isinstance(obs_space, spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
subspaces = obs_space.spaces
elif isinstance(obs_space, gym.spaces.Tuple):
elif isinstance(obs_space, spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
else:
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"

View file

@ -2,8 +2,8 @@ import pickle
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
import gym
import numpy as np
from gym import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.preprocessing import is_image_space
@ -48,14 +48,14 @@ class VecNormalize(VecEnvWrapper):
if self.norm_obs:
self._sanity_checks()
if isinstance(self.observation_space, gym.spaces.Dict):
if isinstance(self.observation_space, spaces.Dict):
self.obs_spaces = self.observation_space.spaces
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
# Update observation space when using image
# See explanation below and GH #1214
for key in self.obs_rms.keys():
if is_image_space(self.obs_spaces[key]):
self.observation_space.spaces[key] = gym.spaces.Box(
self.observation_space.spaces[key] = spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.obs_spaces[key].shape,
@ -74,7 +74,7 @@ class VecNormalize(VecEnvWrapper):
# in other cases but this will cause backward-incompatible change
# and break already saved policies.
if is_image_space(self.observation_space):
self.observation_space = gym.spaces.Box(
self.observation_space = spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.observation_space.shape,
@ -98,13 +98,13 @@ class VecNormalize(VecEnvWrapper):
"""
Check the observations that are going to be normalized are of the correct type (spaces.Box).
"""
if isinstance(self.observation_space, gym.spaces.Dict):
if isinstance(self.observation_space, spaces.Dict):
# By default, we normalize all keys
if self.norm_obs_keys is None:
self.norm_obs_keys = list(self.observation_space.spaces.keys())
# Check that all keys are of type Box
for obs_key in self.norm_obs_keys:
if not isinstance(self.observation_space.spaces[obs_key], gym.spaces.Box):
if not isinstance(self.observation_space.spaces[obs_key], spaces.Box):
raise ValueError(
f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
f"is of type {self.observation_space.spaces[obs_key]}. "
@ -112,7 +112,7 @@ class VecNormalize(VecEnvWrapper):
" that should be normalized via the `norm_obs_keys` parameter."
)
elif isinstance(self.observation_space, gym.spaces.Box):
elif isinstance(self.observation_space, spaces.Box):
if self.norm_obs_keys is not None:
raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")
@ -143,7 +143,7 @@ class VecNormalize(VecEnvWrapper):
:param state:"""
# Backward compatibility
if "norm_obs_keys" not in state and isinstance(state["observation_space"], gym.spaces.Dict):
if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict):
state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
self.__dict__.update(state)
assert "venv" not in state

View file

@ -1,9 +1,9 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
@ -116,7 +116,7 @@ class DQN(OffPolicyAlgorithm):
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Discrete,),
supported_action_spaces=(spaces.Discrete,),
support_multi_env=True,
)

View file

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type
import gym
import torch as th
from gym import spaces
from torch import nn
from stable_baselines3.common.policies import BasePolicy
@ -29,8 +29,8 @@ class QNetwork(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
features_extractor: nn.Module,
features_dim: int,
net_arch: Optional[List[int]] = None,
@ -106,8 +106,8 @@ class DQNPolicy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -227,8 +227,8 @@ class CnnPolicy(DQNPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -272,8 +272,8 @@ class MultiInputPolicy(DQNPolicy):
def __init__(
self,
observation_space: gym.spaces.Dict,
action_space: gym.spaces.Space,
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,

View file

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
from gym import spaces
from torch import nn
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
@ -47,8 +47,8 @@ class Actor(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
@ -207,8 +207,8 @@ class SACPolicy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -385,8 +385,8 @@ class CnnPolicy(SACPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -451,8 +451,8 @@ class MultiInputPolicy(SACPolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,

View file

@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
@ -133,7 +133,7 @@ class SAC(OffPolicyAlgorithm):
sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Box),
supported_action_spaces=(spaces.Box),
support_multi_env=True,
)

View file

@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Type, Union
import gym
import torch as th
from gym import spaces
from torch import nn
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
@ -34,8 +34,8 @@ class Actor(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
net_arch: List[int],
features_extractor: nn.Module,
features_dim: int,
@ -108,8 +108,8 @@ class TD3Policy(BasePolicy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -271,8 +271,8 @@ class CnnPolicy(TD3Policy):
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
@ -325,8 +325,8 @@ class MultiInputPolicy(TD3Policy):
def __init__(
self,
observation_space: gym.spaces.Dict,
action_space: gym.spaces.Space,
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,

View file

@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import gym
import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import ReplayBuffer
@ -116,7 +116,7 @@ class TD3(OffPolicyAlgorithm):
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Box),
supported_action_spaces=(spaces.Box),
support_multi_env=True,
)

View file

@ -1 +1 @@
1.7.0a10
1.7.0a11

View file

@ -1,14 +1,14 @@
import gym
import numpy as np
import pytest
from gym.spaces import Box, Dict, Discrete
from gym import spaces
from stable_baselines3.common.env_checker import check_env
class ActionDictTestEnv(gym.Env):
action_space = Dict({"position": Discrete(1), "velocity": Discrete(1)})
observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)
action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)})
observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)
def step(self, action):
observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)

View file

@ -2,6 +2,7 @@ import gym
import numpy as np
import pytest
import torch as th
from gym import spaces
from stable_baselines3 import A2C, PPO, SAC
from stable_baselines3.common.callbacks import BaseCallback
@ -11,8 +12,8 @@ from stable_baselines3.common.policies import ActorCriticPolicy
class CustomEnv(gym.Env):
def __init__(self, max_steps=8):
super().__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.max_steps = max_steps
self.n_steps = 0
@ -39,8 +40,8 @@ class InfiniteHorizonEnv(gym.Env):
def __init__(self, n_states=4):
super().__init__()
self.n_states = n_states
self.observation_space = gym.spaces.Discrete(n_states)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.observation_space = spaces.Discrete(n_states)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.current_state = 0
def reset(self):

View file

@ -8,6 +8,7 @@ import gym
import numpy as np
import pytest
import torch as th
from gym import spaces
from matplotlib import pyplot as plt
from pandas.errors import EmptyDataError
@ -350,8 +351,8 @@ class TimeDelayEnv(gym.Env):
def __init__(self, delay: float = 0.01):
super().__init__()
self.delay = delay
self.observation_space = gym.spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
self.action_space = gym.spaces.Discrete(2)
self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
self.action_space = spaces.Discrete(2)
def reset(self):
return self.observation_space.sample()

View file

@ -2,6 +2,7 @@ import gym
import numpy as np
import pytest
import torch as th
from gym import spaces
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.envs import IdentityEnv
@ -17,7 +18,7 @@ MODEL_LIST = [
]
class SubClassedBox(gym.spaces.Box):
class SubClassedBox(spaces.Box):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View file

@ -1,6 +1,7 @@
import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_util import make_vec_env
@ -10,8 +11,8 @@ from stable_baselines3.common.evaluation import evaluate_policy
class DummyMultiDiscreteSpace(gym.Env):
def __init__(self, nvec):
super().__init__()
self.observation_space = gym.spaces.MultiDiscrete(nvec)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.observation_space = spaces.MultiDiscrete(nvec)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
@ -23,8 +24,8 @@ class DummyMultiDiscreteSpace(gym.Env):
class DummyMultiBinary(gym.Env):
def __init__(self, n):
super().__init__()
self.observation_space = gym.spaces.MultiBinary(n)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.observation_space = spaces.MultiBinary(n)
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
@ -36,8 +37,8 @@ class DummyMultiBinary(gym.Env):
class DummyMultidimensionalAction(gym.Env):
def __init__(self):
super().__init__()
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
def reset(self):
return self.observation_space.sample()
@ -55,7 +56,7 @@ def test_identity_spaces(model_class, env):
"""
# DQN only support discrete actions
if model_class == DQN:
env.action_space = gym.spaces.Discrete(4)
env.action_space = spaces.Discrete(4)
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)

View file

@ -6,6 +6,7 @@ import multiprocessing
import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
@ -64,7 +65,7 @@ class CustomGymEnv(gym.Env):
def test_vecenv_func_checker():
"""The functions in ``env_fns'' must return distinct instances since we need distinct environments."""
env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
env = CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
with pytest.raises(ValueError):
DummyVecEnv([lambda: env for _ in range(N_ENVS)])
@ -78,7 +79,7 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
"""Test access to methods/attributes of vectorized environments"""
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
@ -147,8 +148,8 @@ class StepEnv(gym.Env):
def __init__(self, max_steps):
"""Gym environment for testing that terminal observation is inserted
correctly."""
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]), dtype="int")
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Box(np.array([0]), np.array([999]), dtype="int")
self.max_steps = max_steps
self.current_step = 0
@ -210,10 +211,10 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
SPACES = collections.OrderedDict(
[
("discrete", gym.spaces.Discrete(2)),
("multidiscrete", gym.spaces.MultiDiscrete([2, 3])),
("multibinary", gym.spaces.MultiBinary(3)),
("continuous", gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
("discrete", spaces.Discrete(2)),
("multidiscrete", spaces.MultiDiscrete([2, 3])),
("multibinary", spaces.MultiBinary(3)),
("continuous", spaces.Box(low=np.zeros(2), high=np.ones(2))),
]
)
@ -252,7 +253,7 @@ def test_vecenv_single_space(vec_env_class, space):
check_vecenv_spaces(vec_env_class, space, obs_assert)
class _UnorderedDictSpace(gym.spaces.Dict):
class _UnorderedDictSpace(spaces.Dict):
"""Like DictSpace, but returns an unordered dict when sampling."""
def sample(self):
@ -262,7 +263,7 @@ class _UnorderedDictSpace(gym.spaces.Dict):
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_dict_spaces(vec_env_class):
"""Test dictionary observation spaces with vectorized environments."""
space = gym.spaces.Dict(SPACES)
space = spaces.Dict(SPACES)
def obs_assert(obs):
assert isinstance(obs, collections.OrderedDict)
@ -280,7 +281,7 @@ def test_vecenv_dict_spaces(vec_env_class):
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_tuple_spaces(vec_env_class):
"""Test tuple observation spaces with vectorized environments."""
space = gym.spaces.Tuple(tuple(SPACES.values()))
space = spaces.Tuple(tuple(SPACES.values()))
def obs_assert(obs):
assert isinstance(obs, tuple)
@ -298,7 +299,7 @@ def test_subproc_start_method():
all_methods = {"forkserver", "spawn", "fork"}
available_methods = multiprocessing.get_all_start_methods()
start_methods += list(all_methods.intersection(available_methods))
space = gym.spaces.Discrete(2)
space = spaces.Discrete(2)
def obs_assert(obs):
return check_vecenv_obs(obs, space)
@ -338,7 +339,7 @@ class CustomWrapperBB(CustomWrapperB):
def test_vecenv_wrapper_getattr():
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
@ -367,7 +368,7 @@ def test_framestack_vecenv():
def make_image_env():
return CustomGymEnv(
gym.spaces.Box(
spaces.Box(
low=np.zeros(image_space_shape),
high=np.ones(image_space_shape) * 255,
dtype=np.uint8,
@ -376,7 +377,7 @@ def test_framestack_vecenv():
def make_transposed_image_env():
return CustomGymEnv(
gym.spaces.Box(
spaces.Box(
low=np.zeros(transposed_image_space_shape),
high=np.ones(transposed_image_space_shape) * 255,
dtype=np.uint8,
@ -384,7 +385,7 @@ def test_framestack_vecenv():
)
def make_non_image_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))
return CustomGymEnv(spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
vec_env = VecFrameStack(vec_env, n_stack=2)
@ -433,10 +434,10 @@ def test_framestack_vecenv():
def test_vec_env_is_wrapped():
# Test is_wrapped call of subproc workers
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
def make_monitored_env():
return Monitor(CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))))
return Monitor(CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))))
# One with monitor, one without
vec_env = SubprocVecEnv([make_env, make_monitored_env])
@ -457,7 +458,7 @@ def test_vec_env_is_wrapped():
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vec_seeding(vec_env_class):
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
# For SubprocVecEnv check for all starting methods
start_methods = [None]

View file

@ -23,8 +23,8 @@ class DummyRewardEnv(gym.Env):
metadata = {}
def __init__(self, return_reward_idx=0):
self.action_space = gym.spaces.Discrete(2)
self.observation_space = gym.spaces.Box(low=np.array([-1.0]), high=np.array([1.0]))
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Box(low=np.array([-1.0]), high=np.array([1.0]))
self.returned_rewards = [0, 1, 3, 4]
self.return_reward_idx = return_reward_idx
self.t = self.return_reward_idx