Type annotation bundle (logger, vec env, custom envs) (#1479)

* Switch from List to Sequence for `seed()` type hint

* Fix logger type hints

* Improve replay buffer type hints

* Fix custom envs type annotations

* Fix VecMonitor type hints

* Fix RMSprop type hint

* Fix vec extract dict obs type hints

* Fix vec frame stack type annotations

* Fix base vec env type hints

* Fix dummy vec env type hints

* Fix for mypy

* Fixes for the tests

* mypy doesn't like when we overwrite type

* fix step of SimpleMultiObsEnv

* remove useless type specification

* Rm useless type hint

* Improve logger type hint

* format

* rm useless type hint

* Re-add variables in constructor, remove unused import

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
Antonin RAFFIN 2023-05-04 20:27:15 +02:00 committed by GitHub
parent d6ddee9366
commit 63a0bb9da1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 113 additions and 96 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 2.0.0a6 (WIP)
Release 2.0.0a7 (WIP)
--------------------------
**Gymnasium support**
@ -48,12 +48,19 @@ Others:
- Fixed ``stable_baselines3/sac/*.py`` type hints
- Fixed ``stable_baselines3/td3/*.py`` type hints
- Fixed ``stable_baselines3/common/base_class.py`` type hints
- Fixed ``stable_baselines3/common/logger.py`` type hints
- Fixed ``stable_baselines3/common/envs/*.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_monitor|vec_extract_dict_obs|util.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/base_vec_env.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_frame_stack.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/dummy_vec_env.py`` type hints
- Upgraded docker images to use mamba/micromamba and CUDA 11.7
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improved type annotations of wrappers
- Test envs are now checked too
- Added render test for ``VecEnv``
- Update issue templates and env info saved with the model
- Changed ``seed()`` method return type from ``List`` to ``Sequence``
Documentation:
^^^^^^^^^^^^^^

View file

@ -38,23 +38,12 @@ exclude = """(?x)(
stable_baselines3/common/buffers.py$
| stable_baselines3/common/callbacks.py$
| stable_baselines3/common/distributions.py$
| stable_baselines3/common/envs/bit_flipping_env.py$
| stable_baselines3/common/envs/identity_env.py$
| stable_baselines3/common/envs/multi_input_envs.py$
| stable_baselines3/common/logger.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
| stable_baselines3/common/utils.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/base_vec_env.py$
| stable_baselines3/common/vec_env/dummy_vec_env.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/util.py$
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
| stable_baselines3/common/vec_env/vec_frame_stack.py$
| stable_baselines3/common/vec_env/vec_monitor.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$

View file

@ -681,6 +681,8 @@ class DictRolloutBuffer(RolloutBuffer):
:param n_envs: Number of parallel environments
"""
observations: Dict[str, np.ndarray]
def __init__(
self,
buffer_size: int,
@ -697,8 +699,7 @@ class DictRolloutBuffer(RolloutBuffer):
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()

View file

@ -26,6 +26,7 @@ class BitFlippingEnv(Env):
"""
spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point")
state: np.ndarray
def __init__(
self,
@ -35,8 +36,10 @@ class BitFlippingEnv(Env):
discrete_obs_space: bool = False,
image_obs_space: bool = False,
channel_first: bool = True,
render_mode: str = "human",
):
super().__init__()
self.render_mode = render_mode
# Shape of the observation when using image space
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
# The achieved goal is determined by the current state
@ -95,7 +98,6 @@ class BitFlippingEnv(Env):
self.continuous = continuous
self.discrete_obs_space = discrete_obs_space
self.image_obs_space = image_obs_space
self.state = None
self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype)
if max_steps is None:
max_steps = n_bits
@ -127,21 +129,20 @@ class BitFlippingEnv(Env):
"""
Convert to bit vector if needed.
:param state:
:param batch_size:
:return:
:param state: The state to be converted, which can be either an integer or a numpy array.
:param batch_size: The batch size.
:return: The state converted into a bit vector.
"""
# Convert back to bit vector
if isinstance(state, int):
state = np.array(state).reshape(batch_size, -1)
bit_vector = np.array(state).reshape(batch_size, -1)
# Convert to binary representation
state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
bit_vector = ((bit_vector[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
elif self.image_obs_space:
state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
bit_vector = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
else:
state = np.array(state).reshape(batch_size, -1)
return state
bit_vector = np.array(state).reshape(batch_size, -1)
return bit_vector
def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
"""
@ -205,10 +206,11 @@ class BitFlippingEnv(Env):
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
return -(distance > 0).astype(np.float32)
def render(self, mode: str = "human") -> Optional[np.ndarray]:
if mode == "rgb_array":
def render(self) -> Optional[np.ndarray]: # type: ignore[override]
if self.render_mode == "rgb_array":
return self.state.copy()
print(self.state)
return None
def close(self) -> None:
pass

View file

@ -1,4 +1,4 @@
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import gymnasium as gym
import numpy as np
@ -73,7 +73,7 @@ class SimpleMultiObsEnv(gym.Env):
self.init_possible_transitions()
self.num_col = num_col
self.state_mapping = []
self.state_mapping: List[Dict[str, np.ndarray]] = []
self.init_state_mapping(num_col, num_row)
self.max_state = len(self.state_mapping) - 1
@ -121,20 +121,18 @@ class SimpleMultiObsEnv(gym.Env):
self.right_possible = [0, 1, 2, 12, 13, 14]
self.up_possible = [4, 8, 12, 7, 11, 15]
def step(self, action: Union[float, np.ndarray]) -> GymStepReturn:
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
"""
Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling `reset()`
to reset this environment's state.
Accepts an action and returns a tuple (observation, reward, done, info).
Accepts an action and returns a tuple (observation, reward, terminated, truncated, info).
:param action:
:return: tuple (observation, reward, done, info).
:return: tuple (observation, reward, terminated, truncated, info).
"""
if not self.discrete_actions:
action = np.argmax(action)
else:
action = int(action)
action = np.argmax(action) # type: ignore[assignment]
self.count += 1

View file

@ -5,6 +5,7 @@ import sys
import tempfile
import warnings
from collections import defaultdict
from io import TextIOWrapper
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union
import numpy as np
@ -113,7 +114,7 @@ class KVWriter:
Key Value writer
"""
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
"""
Write a dictionary to file
@ -135,7 +136,7 @@ class SeqWriter:
sequence writer
"""
def write_sequence(self, sequence: List) -> None:
def write_sequence(self, sequence: List[str]) -> None:
"""
write_sequence an array to file
@ -163,15 +164,16 @@ class HumanOutputFormat(KVWriter, SeqWriter):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "w")
self.own_file = True
else:
assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
elif isinstance(filename_or_file, TextIOWrapper): # equivalent to `isinstance(..., TextIO)` (not supported)
self.file = filename_or_file
self.own_file = False
else:
raise ValueError(f"Expected file or str, got {filename_or_file}")
def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
# Create strings for printing
key2str = {}
tag = None
tag = ""
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and ("stdout" in excluded or "log" in excluded):
continue
@ -197,9 +199,9 @@ class HumanOutputFormat(KVWriter, SeqWriter):
if key.find("/") > 0: # Find tag and add it to the dict
tag = key[: key.find("/") + 1]
key2str[(tag, self._truncate(tag))] = ""
# Remove tag from key
if tag is not None and tag in key:
key = str(" " + key[len(tag) :])
# Remove tag from key and indent the key
if len(tag) > 0 and tag in key:
key = f"{'':3}{key[len(tag) :]}"
truncated_key = self._truncate(key)
if (tag, truncated_key) in key2str:
@ -240,8 +242,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
string = string[: self.max_length - 3] + "..."
return string
def write_sequence(self, sequence: List) -> None:
sequence = list(sequence)
def write_sequence(self, sequence: List[str]) -> None:
for i, elem in enumerate(sequence):
self.file.write(elem)
if i < len(sequence) - 1: # add space unless this is the last one
@ -257,9 +258,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
self.file.close()
def filter_excluded_keys(
key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
) -> Dict[str, Any]:
def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
"""
Filters the keys specified by ``key_exclude`` for the specified format
@ -285,7 +284,7 @@ class JSONOutputFormat(KVWriter):
def __init__(self, filename: str):
self.file = open(filename, "w")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def cast_to_json_serializable(value: Any):
if isinstance(value, Video):
raise FormatUnsupportedError(["json"], "video")
@ -332,7 +331,7 @@ class CSVOutputFormat(KVWriter):
self.separator = ","
self.quotechar = '"'
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
# Add our current row to the history
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
extra_keys = key_values.keys() - self.keys
@ -394,10 +393,12 @@ class TensorBoardOutputFormat(KVWriter):
"""
def __init__(self, folder: str):
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
self.writer = SummaryWriter(log_dir=folder)
self._is_closed = False
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and "tensorboard" in excluded:
continue
@ -437,7 +438,7 @@ class TensorBoardOutputFormat(KVWriter):
"""
if self.writer:
self.writer.close()
self.writer = None
self._is_closed = True
def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
@ -478,13 +479,24 @@ class Logger:
"""
def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
self.name_to_value = defaultdict(float) # values this iteration
self.name_to_count = defaultdict(int)
self.name_to_excluded = defaultdict(str)
self.name_to_value: Dict[str, float] = defaultdict(float) # values this iteration
self.name_to_count: Dict[str, int] = defaultdict(int)
self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
self.level = INFO
self.dir = folder
self.output_formats = output_formats
@staticmethod
def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
"""
Helper function to convert str to tuple of str.
"""
if string_or_tuple is None:
return ("",)
if isinstance(string_or_tuple, tuple):
return string_or_tuple
return (string_or_tuple,)
def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
Log a value of some diagnostic
@ -496,9 +508,9 @@ class Logger:
:param exclude: outputs to be excluded
"""
self.name_to_value[key] = value
self.name_to_excluded[key] = exclude
self.name_to_excluded[key] = self.to_tuple(exclude)
def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
The same as record(), but if called many times, values averaged.
@ -507,12 +519,11 @@ class Logger:
:param exclude: outputs to be excluded
"""
if value is None:
self.name_to_value[key] = None
return
old_val, count = self.name_to_value[key], self.name_to_count[key]
self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
self.name_to_count[key] = count + 1
self.name_to_excluded[key] = exclude
self.name_to_excluded[key] = self.to_tuple(exclude)
def dump(self, step: int = 0) -> None:
"""
@ -592,7 +603,7 @@ class Logger:
"""
self.level = level
def get_dir(self) -> str:
def get_dir(self) -> Optional[str]:
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
@ -610,7 +621,7 @@ class Logger:
# Misc
# ----------------------------------------
def _do_log(self, args) -> None:
def _do_log(self, args: Tuple[Any, ...]) -> None:
"""
log to the requested format outputs
@ -618,7 +629,7 @@ class Logger:
"""
for _format in self.output_formats:
if isinstance(_format, SeqWriter):
_format.write_sequence(map(str, args))
_format.write_sequence(list(map(str, args)))
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:

View file

@ -74,7 +74,7 @@ class RMSpropTFLike(Optimizer):
group.setdefault("centered", False)
@torch.no_grad()
def step(self, closure: Optional[Callable[[], None]] = None) -> Optional[torch.Tensor]:
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
:param closure: A closure that reevaluates the model

View file

@ -19,17 +19,17 @@ VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
"""
Tile N images into one big PxQ image
(P,Q) are chosen to be as close as possible, and if N
is square, then P=Q.
:param img_nhwc: list or array of images, ndim=4 once turned into array. img nhwc
:param images_nhwc: list or array of images, ndim=4 once turned into array.
n = batch index, h = height, w = width, c = channel
:return: img_HWc, ndim=3
"""
img_nhwc = np.asarray(img_nhwc)
img_nhwc = np.asarray(images_nhwc)
n_images, height, width, n_channels = img_nhwc.shape
# new_height was named H before
new_height = int(np.ceil(np.sqrt(n_images)))
@ -67,7 +67,8 @@ class VecEnv(ABC):
self.observation_space = observation_space
self.action_space = action_space
self.render_mode = render_mode
self.reset_infos = [{} for _ in range(num_envs)] # store info returned by the reset method
# store info returned by the reset method
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
@abstractmethod
def reset(self) -> VecEnvObs:
@ -192,7 +193,7 @@ class VecEnv(ABC):
"but the render mode defined when initializing the environment must be "
f"'human' or 'rgb_array', not '{self.render_mode}'."
)
return
return None
elif mode and self.render_mode != mode:
warnings.warn(
@ -200,26 +201,26 @@ class VecEnv(ABC):
We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode})
has to be the same as the environment render mode ({self.render_mode}) which is not the case."""
)
return
return None
mode = mode or self.render_mode
if mode is None:
warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.")
return
return None
# mode == self.render_mode == "human"
# In that case, we try to call `self.env.render()` but it might
# crash for subprocesses
if self.render_mode == "human":
self.env_method("render")
return
return None
if mode == "rgb_array" or mode == "human":
# call the render method of the environments
images = self.get_images()
# Create a big image by tiling images from subprocesses
bigimg = tile_images(images)
bigimg = tile_images(images) # type: ignore[arg-type]
if mode == "human":
# Display it using OpenCV
@ -236,9 +237,10 @@ class VecEnv(ABC):
# crash for subprocesses
# and we don't return the values
self.env_method("render")
return None
@abstractmethod
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
"""
Sets the random seeds for all environments, based on a given seed.
Each individual environment will still get its own seed, by incrementing the given seed.
@ -319,7 +321,7 @@ class VecEnvWrapper(VecEnv):
def step_wait(self) -> VecEnvStepReturn:
pass
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
return self.venv.seed(seed)
def close(self) -> None:
@ -394,7 +396,7 @@ class VecEnvWrapper(VecEnv):
all_attributes = self._get_all_attributes()
if name in all_attributes and already_found:
# this venv's attribute is being hidden because of a higher venv.
shadowed_wrapper_class = f"{type(self).__module__}.{type(self).__name__}"
shadowed_wrapper_class: Optional[str] = f"{type(self).__module__}.{type(self).__name__}"
elif name in all_attributes and not already_found:
# we have found the first reference to the attribute. Now check for duplicates.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)

View file

@ -1,7 +1,7 @@
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
import gymnasium as gym
import numpy as np
@ -24,6 +24,8 @@ class DummyVecEnv(VecEnv):
:raises ValueError: If the same environment instance is passed as the output of two or more different env_fn.
"""
actions: np.ndarray
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
self.envs = [_patch_env(fn()) for fn in env_fns]
if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
@ -44,8 +46,7 @@ class DummyVecEnv(VecEnv):
self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos = [{} for _ in range(self.num_envs)]
self.actions = None
self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
self.metadata = env.metadata
def step_async(self, actions: np.ndarray) -> None:
@ -70,7 +71,7 @@ class DummyVecEnv(VecEnv):
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
# Avoid circular import
from stable_baselines3.common.utils import compat_gym_seed
@ -78,7 +79,7 @@ class DummyVecEnv(VecEnv):
seed = np.random.randint(0, 2**32 - 1)
seeds = []
for idx, env in enumerate(self.envs):
seeds.append(compat_gym_seed(env, seed=seed + idx))
seeds.append(compat_gym_seed(env, seed=seed + idx)) # type: ignore[func-returns-value]
return seeds
def reset(self) -> VecEnvObs:
@ -97,7 +98,7 @@ class DummyVecEnv(VecEnv):
f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
)
return [None for _ in self.envs]
return [env.render() for env in self.envs]
return [env.render() for env in self.envs] # type: ignore[misc]
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
"""
@ -113,7 +114,7 @@ class DummyVecEnv(VecEnv):
if key is None:
self.buf_obs[key][env_idx] = obs
else:
self.buf_obs[key][env_idx] = obs[key]
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]
def _obs_from_buf(self) -> VecEnvObs:
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))

View file

@ -135,7 +135,7 @@ class SubprocVecEnv(VecEnv):
obs, rews, dones, infos, self.reset_infos = zip(*results)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
if seed is None:
seed = np.random.randint(0, 2**32 - 1)
for idx, remote in enumerate(self.remotes):

View file

@ -62,10 +62,10 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
subspaces = obs_space.spaces
elif isinstance(obs_space, spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
else:
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
subspaces = {None: obs_space}
subspaces = {None: obs_space} # type: ignore[assignment]
keys = []
shapes = {}
dtypes = {}

View file

@ -1,4 +1,5 @@
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
@ -13,14 +14,19 @@ class VecExtractDictObs(VecEnvWrapper):
def __init__(self, venv: VecEnv, key: str):
self.key = key
assert isinstance(
venv.observation_space, spaces.Dict
), f"VecExtractDictObs can only be used with Dict obs space, not {venv.observation_space}"
super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])
def reset(self) -> np.ndarray:
obs = self.venv.reset()
assert isinstance(obs, dict)
return obs[self.key]
def step_wait(self) -> VecEnvStepReturn:
obs, reward, done, infos = self.venv.step_wait()
assert isinstance(obs, dict)
for info in infos:
if "terminal_observation" in info:
info["terminal_observation"] = info["terminal_observation"][self.key]

View file

@ -31,7 +31,7 @@ class VecFrameStack(VecEnvWrapper):
self,
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
observations, rewards, dones, infos = self.venv.step_wait()
observations, infos = self.stacked_obs.update(observations, dones, infos)
observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
return observations, rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
@ -39,5 +39,5 @@ class VecFrameStack(VecEnvWrapper):
Reset all environments
"""
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
observation = self.stacked_obs.reset(observation)
observation = self.stacked_obs.reset(observation) # type: ignore[arg-type]
return observation

View file

@ -49,8 +49,6 @@ class VecMonitor(VecEnvWrapper):
)
VecEnvWrapper.__init__(self, venv)
self.episode_returns = None
self.episode_lengths = None
self.episode_count = 0
self.t_start = time.time()
@ -58,13 +56,15 @@ class VecMonitor(VecEnvWrapper):
if hasattr(venv, "spec") and venv.spec is not None:
env_id = venv.spec.id
self.results_writer: Optional[ResultsWriter] = None
if filename:
self.results_writer = ResultsWriter(
filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords
filename, header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=info_keywords
)
else:
self.results_writer = None
self.info_keywords = info_keywords
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def reset(self) -> VecEnvObs:
obs = self.venv.reset()

View file

@ -1 +1 @@
2.0.0a6
2.0.0a7

View file

@ -13,7 +13,7 @@ from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
from stable_baselines3.common.utils import (
check_shape_equal,
get_parameters_by_name,
@ -349,7 +349,7 @@ def test_vec_noise():
num_actions = 10
mu = np.zeros(num_actions)
sigma = np.ones(num_actions) * 0.4
base: ActionNoise = OrnsteinUhlenbeckActionNoise(mu, sigma)
base = OrnsteinUhlenbeckActionNoise(mu, sigma)
with pytest.raises(ValueError):
vec = VectorizedActionNoise(base, -1)
with pytest.raises(ValueError):