From 63a0bb9da13cb98450b83dae75ea2cce582c9b92 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 4 May 2023 20:27:15 +0200 Subject: [PATCH] Type annotation bundle (logger, vec env, custom envs) (#1479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Switch from List to Sequence for `seed()` type hint * Fix logger type hints * Improve replay buffer type hints * Fix custom envs type annotations * Fix VecMonitor type hints * Fix RMSprop type hint * Fix vec extract dict obs type hints * Fix vec frame stack type annotations * Fix base vec env type hints * Fix dummy vec env type hints * Fix for mypy * Fixes for the tests * mypy doesn't like when we overwrite type * fix step of SimpleMultiObsEnv * remove useless type specification * Rm useless type hint * Improve logger type hint * format * rm useless type hint * Re-add variables in constructor, remove unused import --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/misc/changelog.rst | 9 ++- pyproject.toml | 11 --- stable_baselines3/common/buffers.py | 5 +- .../common/envs/bit_flipping_env.py | 26 +++---- .../common/envs/multi_input_envs.py | 14 ++-- stable_baselines3/common/logger.py | 69 +++++++++++-------- .../common/sb2_compat/rmsprop_tf_like.py | 2 +- .../common/vec_env/base_vec_env.py | 26 +++---- .../common/vec_env/dummy_vec_env.py | 15 ++-- .../common/vec_env/subproc_vec_env.py | 2 +- stable_baselines3/common/vec_env/util.py | 4 +- .../common/vec_env/vec_extract_dict_obs.py | 6 ++ .../common/vec_env/vec_frame_stack.py | 4 +- .../common/vec_env/vec_monitor.py | 10 +-- stable_baselines3/version.txt | 2 +- tests/test_utils.py | 4 +- 16 files changed, 113 insertions(+), 96 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c08a923..60750c2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a6 (WIP) +Release 2.0.0a7 (WIP) -------------------------- **Gymnasium support** @@ -48,12 +48,19 @@ Others: - Fixed ``stable_baselines3/sac/*.py`` type hints - Fixed ``stable_baselines3/td3/*.py`` type hints - Fixed ``stable_baselines3/common/base_class.py`` type hints +- Fixed ``stable_baselines3/common/logger.py`` type hints +- Fixed ``stable_baselines3/common/envs/*.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/vec_monitor|vec_extract_dict_obs|util.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/base_vec_env.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/vec_frame_stack.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/dummy_vec_env.py`` type hints - Upgraded docker images to use mamba/micromamba and CUDA 11.7 - Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks - Improve type annotation of wrappers - Tests envs are now checked too - Added render test for ``VecEnv`` - Update issue templates and env info saved with the model +- Changed ``seed()`` method return type from ``List`` to ``Sequence`` Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index b44edf5..99bd218 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,23 +38,12 @@ exclude = """(?x)( stable_baselines3/common/buffers.py$ | stable_baselines3/common/callbacks.py$ | stable_baselines3/common/distributions.py$ - | stable_baselines3/common/envs/bit_flipping_env.py$ - | stable_baselines3/common/envs/identity_env.py$ - | stable_baselines3/common/envs/multi_input_envs.py$ - | stable_baselines3/common/logger.py$ | stable_baselines3/common/off_policy_algorithm.py$ | stable_baselines3/common/policies.py$ | stable_baselines3/common/save_util.py$ - | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/__init__.py$ - | stable_baselines3/common/vec_env/base_vec_env.py$ - | stable_baselines3/common/vec_env/dummy_vec_env.py$ | stable_baselines3/common/vec_env/subproc_vec_env.py$ - | stable_baselines3/common/vec_env/util.py$ - | stable_baselines3/common/vec_env/vec_extract_dict_obs.py$ - | stable_baselines3/common/vec_env/vec_frame_stack.py$ - | stable_baselines3/common/vec_env/vec_monitor.py$ | stable_baselines3/common/vec_env/vec_normalize.py$ | stable_baselines3/common/vec_env/vec_transpose.py$ | stable_baselines3/common/vec_env/vec_video_recorder.py$ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index e52f08f..fe633e1 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -681,6 +681,8 @@ class DictRolloutBuffer(RolloutBuffer): :param n_envs: Number of parallel environments """ + observations: Dict[str, np.ndarray] + def __init__( self, buffer_size: int, @@ -697,8 +699,7 @@ class DictRolloutBuffer(RolloutBuffer): self.gae_lambda = gae_lambda self.gamma = gamma - self.observations, self.actions, self.rewards, self.advantages = None, None, None, None - self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None + self.generator_ready = False self.reset() diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index ec0de2b..869b463 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -26,6 +26,7 @@ class BitFlippingEnv(Env): """ spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point") + state: np.ndarray def __init__( self, @@ -35,8 +36,10 @@ class BitFlippingEnv(Env): discrete_obs_space: bool = False, image_obs_space: bool = False, channel_first: bool = True, + render_mode: str = "human", ): super().__init__() + self.render_mode = render_mode # Shape of the observation when using image space self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) # The achieved goal is determined by the current state @@ -95,7 +98,6 @@ class BitFlippingEnv(Env): self.continuous = continuous self.discrete_obs_space = discrete_obs_space self.image_obs_space = image_obs_space - self.state = None self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype) if max_steps is None: max_steps = n_bits @@ -127,21 +129,20 @@ class BitFlippingEnv(Env): """ Convert to bit vector if needed. - :param state: - :param batch_size: - :return: + :param state: The state to be converted, which can be either an integer or a numpy array. + :param batch_size: The batch size. + :return: The state converted into a bit vector. """ # Convert back to bit vector if isinstance(state, int): - state = np.array(state).reshape(batch_size, -1) + bit_vector = np.array(state).reshape(batch_size, -1) # Convert to binary representation - state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int) + bit_vector = ((bit_vector[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int) elif self.image_obs_space: - state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 + bit_vector = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 else: - state = np.array(state).reshape(batch_size, -1) - - return state + bit_vector = np.array(state).reshape(batch_size, -1) + return bit_vector def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: """ @@ -205,10 +206,11 @@ class BitFlippingEnv(Env): distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) return -(distance > 0).astype(np.float32) - def render(self, mode: str = "human") -> Optional[np.ndarray]: - if mode == "rgb_array": + def render(self) -> Optional[np.ndarray]: # type: ignore[override] + if self.render_mode == "rgb_array": return self.state.copy() print(self.state) + return None def close(self) -> None: pass diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 3bb0710..d7e222d 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import gymnasium as gym import numpy as np @@ -73,7 +73,7 @@ class SimpleMultiObsEnv(gym.Env): self.init_possible_transitions() self.num_col = num_col - self.state_mapping = [] + self.state_mapping: List[Dict[str, np.ndarray]] = [] self.init_state_mapping(num_col, num_row) self.max_state = len(self.state_mapping) - 1 @@ -121,20 +121,18 @@ class SimpleMultiObsEnv(gym.Env): self.right_possible = [0, 1, 2, 12, 13, 14] self.up_possible = [4, 8, 12, 7, 11, 15] - def step(self, action: Union[float, np.ndarray]) -> GymStepReturn: + def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: """ Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` to reset this environment's state. - Accepts an action and returns a tuple (observation, reward, done, info). + Accepts an action and returns a tuple (observation, reward, terminated, truncated, info). :param action: - :return: tuple (observation, reward, done, info). + :return: tuple (observation, reward, terminated, truncated, info). """ if not self.discrete_actions: - action = np.argmax(action) - else: - action = int(action) + action = np.argmax(action) # type: ignore[assignment] self.count += 1 diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index c379388..f1aadab 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -5,6 +5,7 @@ import sys import tempfile import warnings from collections import defaultdict +from io import TextIOWrapper from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union import numpy as np @@ -113,7 +114,7 @@ class KVWriter: Key Value writer """ - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: """ Write a dictionary to file @@ -135,7 +136,7 @@ class SeqWriter: sequence writer """ - def write_sequence(self, sequence: List) -> None: + def write_sequence(self, sequence: List[str]) -> None: """ write_sequence an array to file @@ -163,15 +164,16 @@ class HumanOutputFormat(KVWriter, SeqWriter): if isinstance(filename_or_file, str): self.file = open(filename_or_file, "w") self.own_file = True - else: - assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}" + elif isinstance(filename_or_file, TextIOWrapper): # equivalent to `isinstance(..., TextIO)` (not supported) self.file = filename_or_file self.own_file = False + else: + raise ValueError(f"Expected file or str, got {filename_or_file}") - def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: # Create strings for printing key2str = {} - tag = None + tag = "" for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): if excluded is not None and ("stdout" in excluded or "log" in excluded): continue @@ -197,9 +199,9 @@ class HumanOutputFormat(KVWriter, SeqWriter): if key.find("/") > 0: # Find tag and add it to the dict tag = key[: key.find("/") + 1] key2str[(tag, self._truncate(tag))] = "" - # Remove tag from key - if tag is not None and tag in key: - key = str(" " + key[len(tag) :]) + # Remove tag from key and indent the key + if len(tag) > 0 and tag in key: + key = f"{'':3}{key[len(tag) :]}" truncated_key = self._truncate(key) if (tag, truncated_key) in key2str: @@ -240,8 +242,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): string = string[: self.max_length - 3] + "..." return string - def write_sequence(self, sequence: List) -> None: - sequence = list(sequence) + def write_sequence(self, sequence: List[str]) -> None: for i, elem in enumerate(sequence): self.file.write(elem) if i < len(sequence) - 1: # add space unless this is the last one @@ -257,9 +258,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): self.file.close() -def filter_excluded_keys( - key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str -) -> Dict[str, Any]: +def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]: """ Filters the keys specified by ``key_exclude`` for the specified format @@ -285,7 +284,7 @@ class JSONOutputFormat(KVWriter): def __init__(self, filename: str): self.file = open(filename, "w") - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: def cast_to_json_serializable(value: Any): if isinstance(value, Video): raise FormatUnsupportedError(["json"], "video") @@ -332,7 +331,7 @@ class CSVOutputFormat(KVWriter): self.separator = "," self.quotechar = '"' - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: # Add our current row to the history key_values = filter_excluded_keys(key_values, key_excluded, "csv") extra_keys = key_values.keys() - self.keys @@ -394,10 +393,12 @@ class TensorBoardOutputFormat(KVWriter): """ def __init__(self, folder: str): - assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so" + assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so" self.writer = SummaryWriter(log_dir=folder) + self._is_closed = False - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: + assert not self._is_closed, "The SummaryWriter was closed, please re-create one." for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): if excluded is not None and "tensorboard" in excluded: continue @@ -437,7 +438,7 @@ class TensorBoardOutputFormat(KVWriter): """ if self.writer: self.writer.close() - self.writer = None + self._is_closed = True def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter: @@ -478,13 +479,24 @@ class Logger: """ def __init__(self, folder: Optional[str], output_formats: List[KVWriter]): - self.name_to_value = defaultdict(float) # values this iteration - self.name_to_count = defaultdict(int) - self.name_to_excluded = defaultdict(str) + self.name_to_value: Dict[str, float] = defaultdict(float) # values this iteration + self.name_to_count: Dict[str, int] = defaultdict(int) + self.name_to_excluded: Dict[str, Tuple[str, ...]] = {} self.level = INFO self.dir = folder self.output_formats = output_formats + @staticmethod + def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]: + """ + Helper function to convert str to tuple of str. + """ + if string_or_tuple is None: + return ("",) + if isinstance(string_or_tuple, tuple): + return string_or_tuple + return (string_or_tuple,) + def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: """ Log a value of some diagnostic @@ -496,9 +508,9 @@ class Logger: :param exclude: outputs to be excluded """ self.name_to_value[key] = value - self.name_to_excluded[key] = exclude + self.name_to_excluded[key] = self.to_tuple(exclude) - def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: """ The same as record(), but if called many times, values averaged. @@ -507,12 +519,11 @@ class Logger: :param exclude: outputs to be excluded """ if value is None: - self.name_to_value[key] = None return old_val, count = self.name_to_value[key], self.name_to_count[key] self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1) self.name_to_count[key] = count + 1 - self.name_to_excluded[key] = exclude + self.name_to_excluded[key] = self.to_tuple(exclude) def dump(self, step: int = 0) -> None: """ @@ -592,7 +603,7 @@ class Logger: """ self.level = level - def get_dir(self) -> str: + def get_dir(self) -> Optional[str]: """ Get directory that log files are being written to. will be None if there is no output directory (i.e., if you didn't call start) @@ -610,7 +621,7 @@ class Logger: # Misc # ---------------------------------------- - def _do_log(self, args) -> None: + def _do_log(self, args: Tuple[Any, ...]) -> None: """ log to the requested format outputs @@ -618,7 +629,7 @@ class Logger: """ for _format in self.output_formats: if isinstance(_format, SeqWriter): - _format.write_sequence(map(str, args)) + _format.write_sequence(list(map(str, args))) def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger: diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py index 377b7f6..9d74798 100644 --- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py @@ -74,7 +74,7 @@ class RMSpropTFLike(Optimizer): group.setdefault("centered", False) @torch.no_grad() - def step(self, closure: Optional[Callable[[], None]] = None) -> Optional[torch.Tensor]: + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: """Performs a single optimization step. :param closure: A closure that reevaluates the model diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index a157503..4a97026 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -19,17 +19,17 @@ VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]] VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]] -def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover +def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover """ Tile N images into one big PxQ image (P,Q) are chosen to be as close as possible, and if N is square, then P=Q. - :param img_nhwc: list or array of images, ndim=4 once turned into array. img nhwc + :param images_nhwc: list or array of images, ndim=4 once turned into array. n = batch index, h = height, w = width, c = channel :return: img_HWc, ndim=3 """ - img_nhwc = np.asarray(img_nhwc) + img_nhwc = np.asarray(images_nhwc) n_images, height, width, n_channels = img_nhwc.shape # new_height was named H before new_height = int(np.ceil(np.sqrt(n_images))) @@ -67,7 +67,8 @@ class VecEnv(ABC): self.observation_space = observation_space self.action_space = action_space self.render_mode = render_mode - self.reset_infos = [{} for _ in range(num_envs)] # store info returned by the reset method + # store info returned by the reset method + self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)] @abstractmethod def reset(self) -> VecEnvObs: @@ -192,7 +193,7 @@ class VecEnv(ABC): "but the render mode defined when initializing the environment must be " f"'human' or 'rgb_array', not '{self.render_mode}'." ) - return + return None elif mode and self.render_mode != mode: warnings.warn( @@ -200,26 +201,26 @@ class VecEnv(ABC): We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode}) has to be the same as the environment render mode ({self.render_mode}) which is not the case.""" ) - return + return None mode = mode or self.render_mode if mode is None: warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.") - return + return None # mode == self.render_mode == "human" # In that case, we try to call `self.env.render()` but it might # crash for subprocesses if self.render_mode == "human": self.env_method("render") - return + return None if mode == "rgb_array" or mode == "human": # call the render method of the environments images = self.get_images() # Create a big image by tiling images from subprocesses - bigimg = tile_images(images) + bigimg = tile_images(images) # type: ignore[arg-type] if mode == "human": # Display it using OpenCV @@ -236,9 +237,10 @@ class VecEnv(ABC): # crash for subprocesses # and we don't return the values self.env_method("render") + return None @abstractmethod - def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: """ Sets the random seeds for all environments, based on a given seed. Each individual environment will still get its own seed, by incrementing the given seed. @@ -319,7 +321,7 @@ class VecEnvWrapper(VecEnv): def step_wait(self) -> VecEnvStepReturn: pass - def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: return self.venv.seed(seed) def close(self) -> None: @@ -394,7 +396,7 @@ class VecEnvWrapper(VecEnv): all_attributes = self._get_all_attributes() if name in all_attributes and already_found: # this venv's attribute is being hidden because of a higher venv. - shadowed_wrapper_class = f"{type(self).__module__}.{type(self).__name__}" + shadowed_wrapper_class: Optional[str] = f"{type(self).__module__}.{type(self).__name__}" elif name in all_attributes and not already_found: # we have found the first reference to the attribute. Now check for duplicates. shadowed_wrapper_class = self.venv.getattr_depth_check(name, True) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 7f092e0..822025f 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,7 +1,7 @@ import warnings from collections import OrderedDict from copy import deepcopy -from typing import Any, Callable, List, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union import gymnasium as gym import numpy as np @@ -24,6 +24,8 @@ class DummyVecEnv(VecEnv): :raises ValueError: If the same environment instance is passed as the output of two or more different env_fn. """ + actions: np.ndarray + def __init__(self, env_fns: List[Callable[[], gym.Env]]): self.envs = [_patch_env(fn()) for fn in env_fns] if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs): @@ -44,8 +46,7 @@ class DummyVecEnv(VecEnv): self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys]) self.buf_dones = np.zeros((self.num_envs,), dtype=bool) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) - self.buf_infos = [{} for _ in range(self.num_envs)] - self.actions = None + self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)] self.metadata = env.metadata def step_async(self, actions: np.ndarray) -> None: @@ -70,7 +71,7 @@ class DummyVecEnv(VecEnv): self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) - def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: # Avoid circular import from stable_baselines3.common.utils import compat_gym_seed @@ -78,7 +79,7 @@ class DummyVecEnv(VecEnv): seed = np.random.randint(0, 2**32 - 1) seeds = [] for idx, env in enumerate(self.envs): - seeds.append(compat_gym_seed(env, seed=seed + idx)) + seeds.append(compat_gym_seed(env, seed=seed + idx)) # type: ignore[func-returns-value] return seeds def reset(self) -> VecEnvObs: @@ -97,7 +98,7 @@ class DummyVecEnv(VecEnv): f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." ) return [None for _ in self.envs] - return [env.render() for env in self.envs] + return [env.render() for env in self.envs] # type: ignore[misc] def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: """ @@ -113,7 +114,7 @@ class DummyVecEnv(VecEnv): if key is None: self.buf_obs[key][env_idx] = obs else: - self.buf_obs[key][env_idx] = obs[key] + self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload] def _obs_from_buf(self) -> VecEnvObs: return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index ccefd20..d3fab98 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -135,7 +135,7 @@ class SubprocVecEnv(VecEnv): obs, rews, dones, infos, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos - def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: if seed is None: seed = np.random.randint(0, 2**32 - 1) for idx, remote in enumerate(self.remotes): diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 6d55db8..2a03d8e 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -62,10 +62,10 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): - subspaces = {i: space for i, space in enumerate(obs_space.spaces)} + subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] else: assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" - subspaces = {None: obs_space} + subspaces = {None: obs_space} # type: ignore[assignment] keys = [] shapes = {} dtypes = {} diff --git a/stable_baselines3/common/vec_env/vec_extract_dict_obs.py b/stable_baselines3/common/vec_env/vec_extract_dict_obs.py index 66872dd..72679b1 100644 --- a/stable_baselines3/common/vec_env/vec_extract_dict_obs.py +++ b/stable_baselines3/common/vec_env/vec_extract_dict_obs.py @@ -1,4 +1,5 @@ import numpy as np +from gymnasium import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper @@ -13,14 +14,19 @@ class VecExtractDictObs(VecEnvWrapper): def __init__(self, venv: VecEnv, key: str): self.key = key + assert isinstance( + venv.observation_space, spaces.Dict + ), f"VecExtractDictObs can only be used with Dict obs space, not {venv.observation_space}" super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) def reset(self) -> np.ndarray: obs = self.venv.reset() + assert isinstance(obs, dict) return obs[self.key] def step_wait(self) -> VecEnvStepReturn: obs, reward, done, infos = self.venv.step_wait() + assert isinstance(obs, dict) for info in infos: if "terminal_observation" in info: info["terminal_observation"] = info["terminal_observation"][self.key] diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 200201f..2396664 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -31,7 +31,7 @@ class VecFrameStack(VecEnvWrapper): self, ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: observations, rewards, dones, infos = self.venv.step_wait() - observations, infos = self.stacked_obs.update(observations, dones, infos) + observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type] return observations, rewards, dones, infos def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: @@ -39,5 +39,5 @@ class VecFrameStack(VecEnvWrapper): Reset all environments """ observation = self.venv.reset() # pytype:disable=annotation-type-mismatch - observation = self.stacked_obs.reset(observation) + observation = self.stacked_obs.reset(observation) # type: ignore[arg-type] return observation diff --git a/stable_baselines3/common/vec_env/vec_monitor.py b/stable_baselines3/common/vec_env/vec_monitor.py index ddc099a..0d7f18a 100644 --- a/stable_baselines3/common/vec_env/vec_monitor.py +++ b/stable_baselines3/common/vec_env/vec_monitor.py @@ -49,8 +49,6 @@ class VecMonitor(VecEnvWrapper): ) VecEnvWrapper.__init__(self, venv) - self.episode_returns = None - self.episode_lengths = None self.episode_count = 0 self.t_start = time.time() @@ -58,13 +56,15 @@ class VecMonitor(VecEnvWrapper): if hasattr(venv, "spec") and venv.spec is not None: env_id = venv.spec.id + self.results_writer: Optional[ResultsWriter] = None if filename: self.results_writer = ResultsWriter( - filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords + filename, header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=info_keywords ) - else: - self.results_writer = None + self.info_keywords = info_keywords + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) def reset(self) -> VecEnvObs: obs = self.venv.reset() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 758664d..4b23e04 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a6 +2.0.0a7 diff --git a/tests/test_utils.py b/tests/test_utils.py index 02128ed..4cc8b7e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,7 +13,7 @@ from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor -from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise +from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise from stable_baselines3.common.utils import ( check_shape_equal, get_parameters_by_name, @@ -349,7 +349,7 @@ def test_vec_noise(): num_actions = 10 mu = np.zeros(num_actions) sigma = np.ones(num_actions) * 0.4 - base: ActionNoise = OrnsteinUhlenbeckActionNoise(mu, sigma) + base = OrnsteinUhlenbeckActionNoise(mu, sigma) with pytest.raises(ValueError): vec = VectorizedActionNoise(base, -1) with pytest.raises(ValueError):