mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Standardize the use of from gym import spaces (#1240)
* generalize the use of `from gym import spaces` * command line get system info * Documentation line length for doc * update changelog * add space before os platform to avoid ref to other issue * format * get_system_info update in changelog * fix type check error * fix get system info * add comment about regex * update version
This commit is contained in:
parent
2bb8ef5e63
commit
4fa17dcf0f
34 changed files with 219 additions and 196 deletions
5
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
5
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
|
|
@ -52,9 +52,8 @@ body:
|
|||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
```python
|
||||
import stable_baselines3 as sb3
|
||||
sb3.get_system_info()
|
||||
```sh
|
||||
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||
```
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
|
|
|
|||
30
.github/ISSUE_TEMPLATE/custom_env.yml
vendored
30
.github/ISSUE_TEMPLATE/custom_env.yml
vendored
|
|
@ -36,6 +36,7 @@ body:
|
|||
```python
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
|
@ -43,20 +44,20 @@ body:
|
|||
|
||||
class CustomEnv(gym.Env):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,))
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,))
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(6,))
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
||||
def step(self, action):
|
||||
obs = self.observation_space.sample()
|
||||
reward = 1.0
|
||||
done = False
|
||||
info = {}
|
||||
return obs, reward, done, info
|
||||
def step(self, action):
|
||||
obs = self.observation_space.sample()
|
||||
reward = 1.0
|
||||
done = False
|
||||
info = {}
|
||||
return obs, reward, done, info
|
||||
|
||||
env = CustomEnv()
|
||||
check_env(env)
|
||||
|
|
@ -86,9 +87,8 @@ body:
|
|||
* Versions of any other relevant libraries
|
||||
|
||||
You can use `sb3.get_system_info()` to print relevant packages info:
|
||||
```python
|
||||
import stable_baselines3 as sb3
|
||||
sb3.get_system_info()
|
||||
```sh
|
||||
python -c 'import stable_baselines3 as sb3; sb3.get_system_info()'
|
||||
```
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ pip install -e .[docs,tests,extra]
|
|||
|
||||
## Codestyle
|
||||
|
||||
We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
|
||||
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
|
||||
For the documentation, we use the default line length of 88 characters.
|
||||
|
||||
**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
|
||||
|
||||
|
|
|
|||
|
|
@ -27,14 +27,17 @@ That is to say, your environment must implement the following methods (and inher
|
|||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class CustomEnv(gym.Env):
|
||||
"""Custom Environment that follows gym interface"""
|
||||
"""Custom Environment that follows gym interface."""
|
||||
|
||||
metadata = {"render.modes": ["human"]}
|
||||
|
||||
def __init__(self, arg1, arg2, ...):
|
||||
super(CustomEnv, self).__init__()
|
||||
super().__init__()
|
||||
# Define action and observation space
|
||||
# They must be gym.spaces objects
|
||||
# Example when using discrete actions:
|
||||
|
|
@ -46,12 +49,15 @@ That is to say, your environment must implement the following methods (and inher
|
|||
def step(self, action):
|
||||
...
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
...
|
||||
return observation # reward, done, info can't be included
|
||||
|
||||
def render(self, mode="human"):
|
||||
...
|
||||
def close (self):
|
||||
|
||||
def close(self):
|
||||
...
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -125,9 +125,9 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
|
@ -140,7 +140,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
|
|||
This corresponds to the number of units for the last layer.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
|
||||
def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
|
||||
super().__init__(observation_space, features_dim)
|
||||
# We assume CxHxW images (channels first)
|
||||
# Re-ordering will be done by preprocessing or a wrapper
|
||||
|
|
@ -199,7 +199,7 @@ downsampling and "vector" with a single linear layer.
|
|||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
class CustomCombinedExtractor(BaseFeaturesExtractor):
|
||||
def __init__(self, observation_space: gym.spaces.Dict):
|
||||
def __init__(self, observation_space: spaces.Dict):
|
||||
# We do not know features-dim here before going over all the items,
|
||||
# so put something dummy for now. PyTorch requires calling
|
||||
# nn.Module.__init__ before adding modules
|
||||
|
|
@ -310,7 +310,7 @@ If your task requires even more granular control over the policy/value architect
|
|||
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
import torch as th
|
||||
from torch import nn
|
||||
|
||||
|
|
@ -367,8 +367,8 @@ If your task requires even more granular control over the policy/value architect
|
|||
class CustomActorCriticPolicy(ActorCriticPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Callable[[float], float],
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.7.0a10 (WIP)
|
||||
Release 1.7.0a11 (WIP)
|
||||
--------------------------
|
||||
|
||||
.. note::
|
||||
|
|
@ -71,6 +71,8 @@ Others:
|
|||
- Upgraded GitHub CI/setup-python to v4 and checkout to v3
|
||||
- Set tensors construction directly on the device (~8% speed boost on GPU)
|
||||
- Monkey-patched ``np.bool = bool`` so gym 0.21 is compatible with NumPy 1.24+
|
||||
- Standardized the use of ``from gym import spaces``
|
||||
- Modified ``get_system_info`` to avoid issues when copy-pasting its output into a GitHub issue
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Un
|
|||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common import utils
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
|
||||
|
|
@ -101,7 +102,7 @@ class BaseAlgorithm(ABC):
|
|||
seed: Optional[int] = None,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
|
||||
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
|
||||
):
|
||||
if isinstance(policy, str):
|
||||
self.policy_class = self._get_policy_from_name(policy)
|
||||
|
|
@ -117,8 +118,8 @@ class BaseAlgorithm(ABC):
|
|||
self._vec_normalize_env = unwrap_vec_normalize(env)
|
||||
self.verbose = verbose
|
||||
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
|
||||
self.observation_space = None # type: Optional[gym.spaces.Space]
|
||||
self.action_space = None # type: Optional[gym.spaces.Space]
|
||||
self.observation_space = None # type: Optional[spaces.Space]
|
||||
self.action_space = None # type: Optional[spaces.Space]
|
||||
self.n_envs = None
|
||||
self.num_timesteps = 0
|
||||
# Used for updating schedules
|
||||
|
|
@ -175,13 +176,13 @@ class BaseAlgorithm(ABC):
|
|||
)
|
||||
|
||||
# Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
|
||||
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
|
||||
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, spaces.Dict):
|
||||
raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
|
||||
|
||||
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
|
||||
if self.use_sde and not isinstance(self.action_space, spaces.Box):
|
||||
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
assert np.all(
|
||||
np.isfinite(np.array([self.action_space.low, self.action_space.high]))
|
||||
), "Continuous action space must have a finite lower and upper bound"
|
||||
|
|
@ -212,7 +213,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
if not is_vecenv_wrapped(env, VecTransposeImage):
|
||||
wrap_with_vectranspose = False
|
||||
if isinstance(env.observation_space, gym.spaces.Dict):
|
||||
if isinstance(env.observation_space, spaces.Dict):
|
||||
# If even one of the keys is an image space in need of transpose, apply transpose
|
||||
# If the image spaces are not consistent (for instance one is channel first,
|
||||
# the other channel last), VecTransposeImage will throw an error
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
|
@ -659,7 +658,7 @@ class TanhBijector:
|
|||
|
||||
|
||||
def make_proba_distribution(
|
||||
action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
||||
action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
|
||||
) -> Distribution:
|
||||
"""
|
||||
Return an instance of Distribution for the correct type of action space
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
|
|||
|
||||
def _check_spaces(env: gym.Env) -> None:
|
||||
"""
|
||||
Check that the observation and action spaces are defined and inherit from gym.spaces.Space. For
|
||||
Check that the observation and action spaces are defined and inherit from spaces.Space. For
|
||||
envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check
|
||||
the observation space is gym.spaces.Dict
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
from typing import Optional, Union
|
||||
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import Env, Space
|
||||
from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
||||
|
||||
T = TypeVar("T", int, np.ndarray)
|
||||
|
||||
class IdentityEnv(Env):
|
||||
def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
|
||||
|
||||
class IdentityEnv(gym.Env, Generic[T]):
|
||||
def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = None, ep_length: int = 100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
||||
|
|
@ -22,7 +24,7 @@ class IdentityEnv(Env):
|
|||
if space is None:
|
||||
if dim is None:
|
||||
dim = 1
|
||||
space = Discrete(dim)
|
||||
space = spaces.Discrete(dim)
|
||||
else:
|
||||
assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
|
||||
|
||||
|
|
@ -32,13 +34,13 @@ class IdentityEnv(Env):
|
|||
self.num_resets = -1 # Becomes 0 after __init__ exits.
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> GymObs:
|
||||
def reset(self) -> T:
|
||||
self.current_step = 0
|
||||
self.num_resets += 1
|
||||
self._choose_next_state()
|
||||
return self.state
|
||||
|
||||
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
|
||||
def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]:
|
||||
reward = self._get_reward(action)
|
||||
self._choose_next_state()
|
||||
self.current_step += 1
|
||||
|
|
@ -48,14 +50,14 @@ class IdentityEnv(Env):
|
|||
def _choose_next_state(self) -> None:
|
||||
self.state = self.action_space.sample()
|
||||
|
||||
def _get_reward(self, action: Union[int, np.ndarray]) -> float:
|
||||
def _get_reward(self, action: T) -> float:
|
||||
return 1.0 if np.all(self.state == action) else 0.0
|
||||
|
||||
def render(self, mode: str = "human") -> None:
|
||||
pass
|
||||
|
||||
|
||||
class IdentityEnvBox(IdentityEnv):
|
||||
class IdentityEnvBox(IdentityEnv[np.ndarray]):
|
||||
def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
|
@ -65,7 +67,7 @@ class IdentityEnvBox(IdentityEnv):
|
|||
:param eps: the epsilon bound for correct value
|
||||
:param ep_length: the length of each episode in timesteps
|
||||
"""
|
||||
space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
|
||||
space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
|
||||
super().__init__(ep_length=ep_length, space=space)
|
||||
self.eps = eps
|
||||
|
||||
|
|
@ -80,7 +82,7 @@ class IdentityEnvBox(IdentityEnv):
|
|||
return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
|
||||
|
||||
|
||||
class IdentityEnvMultiDiscrete(IdentityEnv):
|
||||
class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
|
||||
def __init__(self, dim: int = 1, ep_length: int = 100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
|
@ -88,11 +90,11 @@ class IdentityEnvMultiDiscrete(IdentityEnv):
|
|||
:param dim: the size of the dimensions you want to learn
|
||||
:param ep_length: the length of each episode in timesteps
|
||||
"""
|
||||
space = MultiDiscrete([dim, dim])
|
||||
space = spaces.MultiDiscrete([dim, dim])
|
||||
super().__init__(ep_length=ep_length, space=space)
|
||||
|
||||
|
||||
class IdentityEnvMultiBinary(IdentityEnv):
|
||||
class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
|
||||
def __init__(self, dim: int = 1, ep_length: int = 100):
|
||||
"""
|
||||
Identity environment for testing purposes
|
||||
|
|
@ -100,11 +102,11 @@ class IdentityEnvMultiBinary(IdentityEnv):
|
|||
:param dim: the size of the dimensions you want to learn
|
||||
:param ep_length: the length of each episode in timesteps
|
||||
"""
|
||||
space = MultiBinary(dim)
|
||||
space = spaces.MultiBinary(dim)
|
||||
super().__init__(ep_length=ep_length, space=space)
|
||||
|
||||
|
||||
class FakeImageEnv(Env):
|
||||
class FakeImageEnv(gym.Env):
|
||||
"""
|
||||
Fake image environment for testing purposes, it mimics Atari games.
|
||||
|
||||
|
|
@ -128,11 +130,11 @@ class FakeImageEnv(Env):
|
|||
self.observation_shape = (screen_height, screen_width, n_channels)
|
||||
if channel_first:
|
||||
self.observation_shape = (n_channels, screen_height, screen_width)
|
||||
self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
|
||||
if discrete:
|
||||
self.action_space = Discrete(action_dim)
|
||||
self.action_space = spaces.Discrete(action_dim)
|
||||
else:
|
||||
self.action_space = Box(low=-1, high=1, shape=(5,), dtype=np.float32)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
|
||||
self.ep_length = 10
|
||||
self.current_step = 0
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Dict, Union
|
|||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.type_aliases import GymStepReturn
|
||||
|
||||
|
|
@ -53,14 +54,14 @@ class SimpleMultiObsEnv(gym.Env):
|
|||
self.random_start = random_start
|
||||
self.discrete_actions = discrete_actions
|
||||
if discrete_actions:
|
||||
self.action_space = gym.spaces.Discrete(4)
|
||||
self.action_space = spaces.Discrete(4)
|
||||
else:
|
||||
self.action_space = gym.spaces.Box(0, 1, (4,))
|
||||
self.action_space = spaces.Box(0, 1, (4,))
|
||||
|
||||
self.observation_space = gym.spaces.Dict(
|
||||
self.observation_space = spaces.Dict(
|
||||
spaces={
|
||||
"vec": gym.spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
|
||||
"img": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8),
|
||||
"vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
|
||||
"img": spaces.Box(0, 255, self.img_size, dtype=np.uint8),
|
||||
}
|
||||
)
|
||||
self.count = 0
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ import warnings
|
|||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
|
||||
|
|
@ -100,7 +100,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
sde_sample_freq: int = -1,
|
||||
use_sde_at_warmup: bool = False,
|
||||
sde_support: bool = True,
|
||||
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
|
||||
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
|
|
@ -173,7 +173,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
|
||||
# Use DictReplayBuffer if needed
|
||||
if self.replay_buffer_class is None:
|
||||
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
self.replay_buffer_class = DictReplayBuffer
|
||||
else:
|
||||
self.replay_buffer_class = ReplayBuffer
|
||||
|
|
@ -395,7 +395,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
scaled_action = self.policy.scale_action(unscaled_action)
|
||||
|
||||
# Add noise to the action (improve exploration)
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ import sys
|
|||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
||||
|
|
@ -70,7 +70,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
seed: Optional[int] = None,
|
||||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True,
|
||||
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
|
||||
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
|
|
@ -103,7 +103,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer
|
||||
buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
|
||||
|
||||
self.rollout_buffer = buffer_cls(
|
||||
self.n_steps,
|
||||
|
|
@ -169,7 +169,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
# Rescale and perform action
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
|
|
@ -184,7 +184,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
self._update_info_buffer(infos)
|
||||
n_steps += 1
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Discrete):
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Reshape in case of discrete action
|
||||
actions = actions.reshape(-1, 1)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ from abc import ABC, abstractmethod
|
|||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.distributions import (
|
||||
|
|
@ -60,8 +60,8 @@ class BaseModel(nn.Module):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
features_extractor: Optional[nn.Module] = None,
|
||||
|
|
@ -344,7 +344,7 @@ class BasePolicy(BaseModel, ABC):
|
|||
# Convert to numpy, and reshape to the original action shape
|
||||
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
if self.squash_output:
|
||||
# Rescale to proper domain when using squashing
|
||||
actions = self.unscale_action(actions)
|
||||
|
|
@ -415,8 +415,8 @@ class ActorCriticPolicy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -749,8 +749,8 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -822,8 +822,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Dict,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -890,8 +890,8 @@ class ContinuousCritic(BaseModel):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Dict, List, Tuple, Type, Union
|
|||
|
||||
import gym
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
|
||||
|
|
@ -63,7 +64,7 @@ class NatureCNN(BaseFeaturesExtractor):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Box,
|
||||
observation_space: spaces.Box,
|
||||
features_dim: int = 512,
|
||||
normalized_image: bool = False,
|
||||
) -> None:
|
||||
|
|
@ -267,7 +268,7 @@ class CombinedExtractor(BaseFeaturesExtractor):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Dict,
|
||||
observation_space: spaces.Dict,
|
||||
cnn_output_dim: int = 256,
|
||||
normalized_image: bool = False,
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import glob
|
|||
import os
|
||||
import platform
|
||||
import random
|
||||
import re
|
||||
from collections import deque
|
||||
from itertools import zip_longest
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
|
@ -9,6 +10,7 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union
|
|||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
import stable_baselines3 as sb3
|
||||
|
||||
|
|
@ -210,7 +212,7 @@ def configure_logger(
|
|||
return configure(save_path, format_strings=format_strings)
|
||||
|
||||
|
||||
def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None:
|
||||
def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
|
||||
"""
|
||||
Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if
|
||||
spaces match after loading the model with given env.
|
||||
|
|
@ -228,7 +230,7 @@ def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, a
|
|||
raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
|
||||
|
||||
|
||||
def is_vectorized_box_observation(observation: np.ndarray, observation_space: gym.spaces.Box) -> bool:
|
||||
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
|
||||
"""
|
||||
For box observation type, detects and validates the shape,
|
||||
then returns whether or not the observation is vectorized.
|
||||
|
|
@ -249,7 +251,7 @@ def is_vectorized_box_observation(observation: np.ndarray, observation_space: gy
|
|||
)
|
||||
|
||||
|
||||
def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Discrete) -> bool:
|
||||
def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: spaces.Discrete) -> bool:
|
||||
"""
|
||||
For discrete observation type, detects and validates the shape,
|
||||
then returns whether or not the observation is vectorized.
|
||||
|
|
@ -269,7 +271,7 @@ def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], obse
|
|||
)
|
||||
|
||||
|
||||
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: gym.spaces.MultiDiscrete) -> bool:
|
||||
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool:
|
||||
"""
|
||||
For multidiscrete observation type, detects and validates the shape,
|
||||
then returns whether or not the observation is vectorized.
|
||||
|
|
@ -290,7 +292,7 @@ def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation
|
|||
)
|
||||
|
||||
|
||||
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: gym.spaces.MultiBinary) -> bool:
|
||||
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
|
||||
"""
|
||||
For multibinary observation type, detects and validates the shape,
|
||||
then returns whether or not the observation is vectorized.
|
||||
|
|
@ -311,7 +313,7 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s
|
|||
)
|
||||
|
||||
|
||||
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: gym.spaces.Dict) -> bool:
|
||||
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool:
|
||||
"""
|
||||
For dict observation type, detects and validates the shape,
|
||||
then returns whether or not the observation is vectorized.
|
||||
|
|
@ -355,7 +357,7 @@ def is_vectorized_dict_observation(observation: np.ndarray, observation_space: g
|
|||
)
|
||||
|
||||
|
||||
def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Space) -> bool:
|
||||
def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: spaces.Space) -> bool:
|
||||
"""
|
||||
For every observation type, detects and validates the shape,
|
||||
then returns whether or not the observation is vectorized.
|
||||
|
|
@ -366,11 +368,11 @@ def is_vectorized_observation(observation: Union[int, np.ndarray], observation_s
|
|||
"""
|
||||
|
||||
is_vec_obs_func_dict = {
|
||||
gym.spaces.Box: is_vectorized_box_observation,
|
||||
gym.spaces.Discrete: is_vectorized_discrete_observation,
|
||||
gym.spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
|
||||
gym.spaces.MultiBinary: is_vectorized_multibinary_observation,
|
||||
gym.spaces.Dict: is_vectorized_dict_observation,
|
||||
spaces.Box: is_vectorized_box_observation,
|
||||
spaces.Discrete: is_vectorized_discrete_observation,
|
||||
spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
|
||||
spaces.MultiBinary: is_vectorized_multibinary_observation,
|
||||
spaces.Dict: is_vectorized_dict_observation,
|
||||
}
|
||||
|
||||
for space_type, is_vec_obs_func in is_vec_obs_func_dict.items():
|
||||
|
|
@ -505,7 +507,9 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
|
|||
and a formatted string.
|
||||
"""
|
||||
env_info = {
|
||||
"OS": f"{platform.platform()} {platform.version()}",
|
||||
# In OS, a regex is used to add a space between a "#" and a number to avoid
|
||||
# wrongly linking to another issue on GitHub. Example: turn "#42" to "# 42".
|
||||
"OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"),
|
||||
"Python": platform.python_version(),
|
||||
"Stable-Baselines3": sb3.__version__,
|
||||
"PyTorch": th.__version__,
|
||||
|
|
@ -515,7 +519,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
|
|||
}
|
||||
env_info_str = ""
|
||||
for key, value in env_info.items():
|
||||
env_info_str += f"{key}: {value}\n"
|
||||
env_info_str += f"- {key}: {value}\n"
|
||||
if print_info:
|
||||
print(env_info_str)
|
||||
return env_info, env_info_str
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, U
|
|||
import cloudpickle
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
# Define type aliases here to avoid circular import
|
||||
# Used when we want to access one or more VecEnv
|
||||
|
|
@ -48,14 +49,14 @@ class VecEnv(ABC):
|
|||
"""
|
||||
An abstract asynchronous, vectorized environment.
|
||||
|
||||
:param num_envs: the number of environments
|
||||
:param observation_space: the observation space
|
||||
:param action_space: the action space
|
||||
:param num_envs: Number of environments
|
||||
:param observation_space: Observation space
|
||||
:param action_space: Action space
|
||||
"""
|
||||
|
||||
metadata = {"render.modes": ["human", "rgb_array"]}
|
||||
|
||||
def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space):
|
||||
def __init__(self, num_envs: int, observation_space: spaces.Space, action_space: spaces.Space):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
|
@ -248,8 +249,8 @@ class VecEnvWrapper(VecEnv):
|
|||
def __init__(
|
||||
self,
|
||||
venv: VecEnv,
|
||||
observation_space: Optional[gym.spaces.Space] = None,
|
||||
action_space: Optional[gym.spaces.Space] = None,
|
||||
observation_space: Optional[spaces.Space] = None,
|
||||
action_space: Optional[spaces.Space] = None,
|
||||
):
|
||||
self.venv = venv
|
||||
VecEnv.__init__(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
|
|||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import (
|
||||
CloudpickleWrapper,
|
||||
|
|
@ -196,7 +197,7 @@ class SubprocVecEnv(VecEnv):
|
|||
return [self.remotes[i] for i in indices]
|
||||
|
||||
|
||||
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.spaces.Space) -> VecEnvObs:
|
||||
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
|
||||
"""
|
||||
Flatten observations, depending on the observation space.
|
||||
|
||||
|
|
@ -210,11 +211,11 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space
|
|||
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
|
||||
assert len(obs) > 0, "need observations from at least one environment"
|
||||
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
if isinstance(space, spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
|
||||
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
|
||||
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
elif isinstance(space, spaces.Tuple):
|
||||
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
|
||||
obs_len = len(space.spaces)
|
||||
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len))
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ Helpers for dealing with vectorized environments.
|
|||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.preprocessing import check_for_nested_spaces
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
|
||||
|
|
@ -22,7 +22,7 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
|||
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
|
||||
|
||||
|
||||
def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
|
||||
def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
|
||||
"""
|
||||
Convert an internal representation obs_dict into the appropriate type
|
||||
specified by obs_space.
|
||||
|
|
@ -33,9 +33,9 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) ->
|
|||
If space is Dict, function is identity; if space is Tuple, converts dict to Tuple;
|
||||
otherwise, the space is unstructured and the function returns the value obs_dict[None].
|
||||
"""
|
||||
if isinstance(obs_space, gym.spaces.Dict):
|
||||
if isinstance(obs_space, spaces.Dict):
|
||||
return obs_dict
|
||||
elif isinstance(obs_space, gym.spaces.Tuple):
|
||||
elif isinstance(obs_space, spaces.Tuple):
|
||||
assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
|
||||
return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
|
||||
else:
|
||||
|
|
@ -43,7 +43,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) ->
|
|||
return obs_dict[None]
|
||||
|
||||
|
||||
def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
|
||||
def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
|
||||
"""
|
||||
Get dict-structured information about a gym.Space.
|
||||
|
||||
|
|
@ -58,10 +58,10 @@ def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tu
|
|||
dtypes: a dict mapping keys to dtypes.
|
||||
"""
|
||||
check_for_nested_spaces(obs_space)
|
||||
if isinstance(obs_space, gym.spaces.Dict):
|
||||
if isinstance(obs_space, spaces.Dict):
|
||||
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
|
||||
subspaces = obs_space.spaces
|
||||
elif isinstance(obs_space, gym.spaces.Tuple):
|
||||
elif isinstance(obs_space, spaces.Tuple):
|
||||
subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
|
||||
else:
|
||||
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ import pickle
|
|||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common import utils
|
||||
from stable_baselines3.common.preprocessing import is_image_space
|
||||
|
|
@ -48,14 +48,14 @@ class VecNormalize(VecEnvWrapper):
|
|||
if self.norm_obs:
|
||||
self._sanity_checks()
|
||||
|
||||
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
self.obs_spaces = self.observation_space.spaces
|
||||
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
|
||||
# Update observation space when using image
|
||||
# See explanation below and GH #1214
|
||||
for key in self.obs_rms.keys():
|
||||
if is_image_space(self.obs_spaces[key]):
|
||||
self.observation_space.spaces[key] = gym.spaces.Box(
|
||||
self.observation_space.spaces[key] = spaces.Box(
|
||||
low=-clip_obs,
|
||||
high=clip_obs,
|
||||
shape=self.obs_spaces[key].shape,
|
||||
|
|
@ -74,7 +74,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
# in other cases but this will cause backward-incompatible change
|
||||
# and break already saved policies.
|
||||
if is_image_space(self.observation_space):
|
||||
self.observation_space = gym.spaces.Box(
|
||||
self.observation_space = spaces.Box(
|
||||
low=-clip_obs,
|
||||
high=clip_obs,
|
||||
shape=self.observation_space.shape,
|
||||
|
|
@ -98,13 +98,13 @@ class VecNormalize(VecEnvWrapper):
|
|||
"""
|
||||
Check the observations that are going to be normalized are of the correct type (spaces.Box).
|
||||
"""
|
||||
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||
if isinstance(self.observation_space, spaces.Dict):
|
||||
# By default, we normalize all keys
|
||||
if self.norm_obs_keys is None:
|
||||
self.norm_obs_keys = list(self.observation_space.spaces.keys())
|
||||
# Check that all keys are of type Box
|
||||
for obs_key in self.norm_obs_keys:
|
||||
if not isinstance(self.observation_space.spaces[obs_key], gym.spaces.Box):
|
||||
if not isinstance(self.observation_space.spaces[obs_key], spaces.Box):
|
||||
raise ValueError(
|
||||
f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
|
||||
f"is of type {self.observation_space.spaces[obs_key]}. "
|
||||
|
|
@ -112,7 +112,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
" that should be normalized via the `norm_obs_keys` parameter."
|
||||
)
|
||||
|
||||
elif isinstance(self.observation_space, gym.spaces.Box):
|
||||
elif isinstance(self.observation_space, spaces.Box):
|
||||
if self.norm_obs_keys is not None:
|
||||
raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")
|
||||
|
||||
|
|
@ -143,7 +143,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
|
||||
:param state:"""
|
||||
# Backward compatibility
|
||||
if "norm_obs_keys" not in state and isinstance(state["observation_space"], gym.spaces.Dict):
|
||||
if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict):
|
||||
state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
|
||||
self.__dict__.update(state)
|
||||
assert "venv" not in state
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
|
|
@ -116,7 +116,7 @@ class DQN(OffPolicyAlgorithm):
|
|||
seed=seed,
|
||||
sde_support=False,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Discrete,),
|
||||
supported_action_spaces=(spaces.Discrete,),
|
||||
support_multi_env=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
|
|
@ -29,8 +29,8 @@ class QNetwork(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
|
|
@ -106,8 +106,8 @@ class DQNPolicy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -227,8 +227,8 @@ class CnnPolicy(DQNPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -272,8 +272,8 @@ class MultiInputPolicy(DQNPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Dict,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
|
@ -47,8 +47,8 @@ class Actor(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
|
|
@ -207,8 +207,8 @@ class SACPolicy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -385,8 +385,8 @@ class CnnPolicy(SACPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -451,8 +451,8 @@ class MultiInputPolicy(SACPolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
|
|
@ -133,7 +133,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Box),
|
||||
supported_action_spaces=(spaces.Box),
|
||||
support_multi_env=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch import nn
|
||||
|
||||
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
|
||||
|
|
@ -34,8 +34,8 @@ class Actor(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
|
|
@ -108,8 +108,8 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -271,8 +271,8 @@ class CnnPolicy(TD3Policy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
@ -325,8 +325,8 @@ class MultiInputPolicy(TD3Policy):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Dict,
|
||||
action_space: gym.spaces.Space,
|
||||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from torch.nn import functional as F
|
||||
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
|
|
@ -116,7 +116,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
seed=seed,
|
||||
sde_support=False,
|
||||
optimize_memory_usage=optimize_memory_usage,
|
||||
supported_action_spaces=(gym.spaces.Box),
|
||||
supported_action_spaces=(spaces.Box),
|
||||
support_multi_env=True,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a10
|
||||
1.7.0a11
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym.spaces import Box, Dict, Discrete
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.env_checker import check_env
|
||||
|
||||
|
||||
class ActionDictTestEnv(gym.Env):
|
||||
action_space = Dict({"position": Discrete(1), "velocity": Discrete(1)})
|
||||
observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)
|
||||
action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)})
|
||||
observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)
|
||||
|
||||
def step(self, action):
|
||||
observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import gym
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import A2C, PPO, SAC
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
|
@ -11,8 +12,8 @@ from stable_baselines3.common.policies import ActorCriticPolicy
|
|||
class CustomEnv(gym.Env):
|
||||
def __init__(self, max_steps=8):
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.max_steps = max_steps
|
||||
self.n_steps = 0
|
||||
|
||||
|
|
@ -39,8 +40,8 @@ class InfiniteHorizonEnv(gym.Env):
|
|||
def __init__(self, n_states=4):
|
||||
super().__init__()
|
||||
self.n_states = n_states
|
||||
self.observation_space = gym.spaces.Discrete(n_states)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.observation_space = spaces.Discrete(n_states)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.current_state = 0
|
||||
|
||||
def reset(self):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import gym
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
from matplotlib import pyplot as plt
|
||||
from pandas.errors import EmptyDataError
|
||||
|
||||
|
|
@ -350,8 +351,8 @@ class TimeDelayEnv(gym.Env):
|
|||
def __init__(self, delay: float = 0.01):
|
||||
super().__init__()
|
||||
self.delay = delay
|
||||
self.observation_space = gym.spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
|
||||
self.action_space = spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import gym
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.envs import IdentityEnv
|
||||
|
|
@ -17,7 +18,7 @@ MODEL_LIST = [
|
|||
]
|
||||
|
||||
|
||||
class SubClassedBox(gym.spaces.Box):
|
||||
class SubClassedBox(spaces.Box):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
|
|
@ -10,8 +11,8 @@ from stable_baselines3.common.evaluation import evaluate_policy
|
|||
class DummyMultiDiscreteSpace(gym.Env):
|
||||
def __init__(self, nvec):
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.MultiDiscrete(nvec)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.observation_space = spaces.MultiDiscrete(nvec)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
|
@ -23,8 +24,8 @@ class DummyMultiDiscreteSpace(gym.Env):
|
|||
class DummyMultiBinary(gym.Env):
|
||||
def __init__(self, n):
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.MultiBinary(n)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.observation_space = spaces.MultiBinary(n)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
|
@ -36,8 +37,8 @@ class DummyMultiBinary(gym.Env):
|
|||
class DummyMultidimensionalAction(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
|
||||
self.observation_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
return self.observation_space.sample()
|
||||
|
|
@ -55,7 +56,7 @@ def test_identity_spaces(model_class, env):
|
|||
"""
|
||||
# DQN only support discrete actions
|
||||
if model_class == DQN:
|
||||
env.action_space = gym.spaces.Discrete(4)
|
||||
env.action_space = spaces.Discrete(4)
|
||||
|
||||
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import multiprocessing
|
|||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
|
||||
|
|
@ -64,7 +65,7 @@ class CustomGymEnv(gym.Env):
|
|||
|
||||
def test_vecenv_func_checker():
|
||||
"""The functions in ``env_fns'' must return distinct instances since we need distinct environments."""
|
||||
env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
env = CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
DummyVecEnv([lambda: env for _ in range(N_ENVS)])
|
||||
|
|
@ -78,7 +79,7 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
|
|||
"""Test access to methods/attributes of vectorized environments"""
|
||||
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
vec_env = vec_env_class([make_env for _ in range(N_ENVS)])
|
||||
|
||||
|
|
@ -147,8 +148,8 @@ class StepEnv(gym.Env):
|
|||
def __init__(self, max_steps):
|
||||
"""Gym environment for testing that terminal observation is inserted
|
||||
correctly."""
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]), dtype="int")
|
||||
self.action_space = spaces.Discrete(2)
|
||||
self.observation_space = spaces.Box(np.array([0]), np.array([999]), dtype="int")
|
||||
self.max_steps = max_steps
|
||||
self.current_step = 0
|
||||
|
||||
|
|
@ -210,10 +211,10 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
|
|||
|
||||
SPACES = collections.OrderedDict(
|
||||
[
|
||||
("discrete", gym.spaces.Discrete(2)),
|
||||
("multidiscrete", gym.spaces.MultiDiscrete([2, 3])),
|
||||
("multibinary", gym.spaces.MultiBinary(3)),
|
||||
("continuous", gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
|
||||
("discrete", spaces.Discrete(2)),
|
||||
("multidiscrete", spaces.MultiDiscrete([2, 3])),
|
||||
("multibinary", spaces.MultiBinary(3)),
|
||||
("continuous", spaces.Box(low=np.zeros(2), high=np.ones(2))),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -252,7 +253,7 @@ def test_vecenv_single_space(vec_env_class, space):
|
|||
check_vecenv_spaces(vec_env_class, space, obs_assert)
|
||||
|
||||
|
||||
class _UnorderedDictSpace(gym.spaces.Dict):
|
||||
class _UnorderedDictSpace(spaces.Dict):
|
||||
"""Like DictSpace, but returns an unordered dict when sampling."""
|
||||
|
||||
def sample(self):
|
||||
|
|
@ -262,7 +263,7 @@ class _UnorderedDictSpace(gym.spaces.Dict):
|
|||
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
|
||||
def test_vecenv_dict_spaces(vec_env_class):
|
||||
"""Test dictionary observation spaces with vectorized environments."""
|
||||
space = gym.spaces.Dict(SPACES)
|
||||
space = spaces.Dict(SPACES)
|
||||
|
||||
def obs_assert(obs):
|
||||
assert isinstance(obs, collections.OrderedDict)
|
||||
|
|
@ -280,7 +281,7 @@ def test_vecenv_dict_spaces(vec_env_class):
|
|||
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
|
||||
def test_vecenv_tuple_spaces(vec_env_class):
|
||||
"""Test tuple observation spaces with vectorized environments."""
|
||||
space = gym.spaces.Tuple(tuple(SPACES.values()))
|
||||
space = spaces.Tuple(tuple(SPACES.values()))
|
||||
|
||||
def obs_assert(obs):
|
||||
assert isinstance(obs, tuple)
|
||||
|
|
@ -298,7 +299,7 @@ def test_subproc_start_method():
|
|||
all_methods = {"forkserver", "spawn", "fork"}
|
||||
available_methods = multiprocessing.get_all_start_methods()
|
||||
start_methods += list(all_methods.intersection(available_methods))
|
||||
space = gym.spaces.Discrete(2)
|
||||
space = spaces.Discrete(2)
|
||||
|
||||
def obs_assert(obs):
|
||||
return check_vecenv_obs(obs, space)
|
||||
|
|
@ -338,7 +339,7 @@ class CustomWrapperBB(CustomWrapperB):
|
|||
|
||||
def test_vecenv_wrapper_getattr():
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
|
||||
wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
|
||||
|
|
@ -367,7 +368,7 @@ def test_framestack_vecenv():
|
|||
|
||||
def make_image_env():
|
||||
return CustomGymEnv(
|
||||
gym.spaces.Box(
|
||||
spaces.Box(
|
||||
low=np.zeros(image_space_shape),
|
||||
high=np.ones(image_space_shape) * 255,
|
||||
dtype=np.uint8,
|
||||
|
|
@ -376,7 +377,7 @@ def test_framestack_vecenv():
|
|||
|
||||
def make_transposed_image_env():
|
||||
return CustomGymEnv(
|
||||
gym.spaces.Box(
|
||||
spaces.Box(
|
||||
low=np.zeros(transposed_image_space_shape),
|
||||
high=np.ones(transposed_image_space_shape) * 255,
|
||||
dtype=np.uint8,
|
||||
|
|
@ -384,7 +385,7 @@ def test_framestack_vecenv():
|
|||
)
|
||||
|
||||
def make_non_image_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))
|
||||
return CustomGymEnv(spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))
|
||||
|
||||
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2)
|
||||
|
|
@ -433,10 +434,10 @@ def test_framestack_vecenv():
|
|||
def test_vec_env_is_wrapped():
|
||||
# Test is_wrapped call of subproc workers
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
def make_monitored_env():
|
||||
return Monitor(CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))))
|
||||
return Monitor(CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2))))
|
||||
|
||||
# One with monitor, one without
|
||||
vec_env = SubprocVecEnv([make_env, make_monitored_env])
|
||||
|
|
@ -457,7 +458,7 @@ def test_vec_env_is_wrapped():
|
|||
@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
|
||||
def test_vec_seeding(vec_env_class):
|
||||
def make_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
return CustomGymEnv(spaces.Box(low=np.zeros(2), high=np.ones(2)))
|
||||
|
||||
# For SubprocVecEnv check for all starting methods
|
||||
start_methods = [None]
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ class DummyRewardEnv(gym.Env):
|
|||
metadata = {}
|
||||
|
||||
def __init__(self, return_reward_idx=0):
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.observation_space = gym.spaces.Box(low=np.array([-1.0]), high=np.array([1.0]))
|
||||
self.action_space = spaces.Discrete(2)
|
||||
self.observation_space = spaces.Box(low=np.array([-1.0]), high=np.array([1.0]))
|
||||
self.returned_rewards = [0, 1, 3, 4]
|
||||
self.return_reward_idx = return_reward_idx
|
||||
self.t = self.return_reward_idx
|
||||
|
|
|
|||
Loading…
Reference in a new issue