mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-28 22:56:53 +00:00
Type annotation bundle (logger, vec env, custom envs) (#1479)
* Switch from List to Sequence for `seed()` type hint
* Fix logger type hints
* Improve replay buffer type hints
* Fix custom envs type annotations
* Fix VecMonitor type hints
* Fix RMSprop type hint
* Fix vec extract dict obs type hints
* Fix vec frame stack type annotations
* Fix base vec env type hints
* Fix dummy vec env type hints
* Fix for mypy
* Fixes for the tests
* mypy doesn't like when we overwrite type
* fix step of SimpleMultiObsEnv
* remove useless type specification
* Rm useless type hint
* Improve logger type hint
* format
* rm useless type hint
* Re-add variables in constructor, remove unused import

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
parent
d6ddee9366
commit
63a0bb9da1
16 changed files with 113 additions and 96 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.0.0a6 (WIP)
|
||||
Release 2.0.0a7 (WIP)
|
||||
--------------------------
|
||||
|
||||
**Gymnasium support**
|
||||
|
|
@ -48,12 +48,19 @@ Others:
|
|||
- Fixed ``stable_baselines3/sac/*.py`` type hints
|
||||
- Fixed ``stable_baselines3/td3/*.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/base_class.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/logger.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/envs/*.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/vec_env/vec_monitor|vec_extract_dict_obs|util.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/vec_env/base_vec_env.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/vec_env/vec_frame_stack.py`` type hints
|
||||
- Fixed ``stable_baselines3/common/vec_env/dummy_vec_env.py`` type hints
|
||||
- Upgraded docker images to use mamba/micromamba and CUDA 11.7
|
||||
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
|
||||
- Improved type annotations of wrappers
|
||||
- Test envs are now checked too
|
||||
- Added render test for ``VecEnv``
|
||||
- Updated issue templates and env info saved with the model
|
||||
- Changed ``seed()`` method return type from ``List`` to ``Sequence``
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -38,23 +38,12 @@ exclude = """(?x)(
|
|||
stable_baselines3/common/buffers.py$
|
||||
| stable_baselines3/common/callbacks.py$
|
||||
| stable_baselines3/common/distributions.py$
|
||||
| stable_baselines3/common/envs/bit_flipping_env.py$
|
||||
| stable_baselines3/common/envs/identity_env.py$
|
||||
| stable_baselines3/common/envs/multi_input_envs.py$
|
||||
| stable_baselines3/common/logger.py$
|
||||
| stable_baselines3/common/off_policy_algorithm.py$
|
||||
| stable_baselines3/common/policies.py$
|
||||
| stable_baselines3/common/save_util.py$
|
||||
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
|
||||
| stable_baselines3/common/utils.py$
|
||||
| stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/base_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/dummy_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/subproc_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/util.py$
|
||||
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
|
||||
| stable_baselines3/common/vec_env/vec_frame_stack.py$
|
||||
| stable_baselines3/common/vec_env/vec_monitor.py$
|
||||
| stable_baselines3/common/vec_env/vec_normalize.py$
|
||||
| stable_baselines3/common/vec_env/vec_transpose.py$
|
||||
| stable_baselines3/common/vec_env/vec_video_recorder.py$
|
||||
|
|
|
|||
|
|
@ -681,6 +681,8 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
:param n_envs: Number of parallel environments
|
||||
"""
|
||||
|
||||
observations: Dict[str, np.ndarray]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size: int,
|
||||
|
|
@ -697,8 +699,7 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
|
||||
self.gae_lambda = gae_lambda
|
||||
self.gamma = gamma
|
||||
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
|
||||
self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
|
||||
|
||||
self.generator_ready = False
|
||||
self.reset()
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class BitFlippingEnv(Env):
|
|||
"""
|
||||
|
||||
spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point")
|
||||
state: np.ndarray
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -35,8 +36,10 @@ class BitFlippingEnv(Env):
|
|||
discrete_obs_space: bool = False,
|
||||
image_obs_space: bool = False,
|
||||
channel_first: bool = True,
|
||||
render_mode: str = "human",
|
||||
):
|
||||
super().__init__()
|
||||
self.render_mode = render_mode
|
||||
# Shape of the observation when using image space
|
||||
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
|
||||
# The achieved goal is determined by the current state
|
||||
|
|
@ -95,7 +98,6 @@ class BitFlippingEnv(Env):
|
|||
self.continuous = continuous
|
||||
self.discrete_obs_space = discrete_obs_space
|
||||
self.image_obs_space = image_obs_space
|
||||
self.state = None
|
||||
self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype)
|
||||
if max_steps is None:
|
||||
max_steps = n_bits
|
||||
|
|
@ -127,21 +129,20 @@ class BitFlippingEnv(Env):
|
|||
"""
|
||||
Convert to bit vector if needed.
|
||||
|
||||
:param state:
|
||||
:param batch_size:
|
||||
:return:
|
||||
:param state: The state to be converted, which can be either an integer or a numpy array.
|
||||
:param batch_size: The batch size.
|
||||
:return: The state converted into a bit vector.
|
||||
"""
|
||||
# Convert back to bit vector
|
||||
if isinstance(state, int):
|
||||
state = np.array(state).reshape(batch_size, -1)
|
||||
bit_vector = np.array(state).reshape(batch_size, -1)
|
||||
# Convert to binary representation
|
||||
state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
|
||||
bit_vector = ((bit_vector[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
|
||||
elif self.image_obs_space:
|
||||
state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
|
||||
bit_vector = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
|
||||
else:
|
||||
state = np.array(state).reshape(batch_size, -1)
|
||||
|
||||
return state
|
||||
bit_vector = np.array(state).reshape(batch_size, -1)
|
||||
return bit_vector
|
||||
|
||||
def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
|
||||
"""
|
||||
|
|
@ -205,10 +206,11 @@ class BitFlippingEnv(Env):
|
|||
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
|
||||
return -(distance > 0).astype(np.float32)
|
||||
|
||||
def render(self, mode: str = "human") -> Optional[np.ndarray]:
|
||||
if mode == "rgb_array":
|
||||
def render(self) -> Optional[np.ndarray]: # type: ignore[override]
|
||||
if self.render_mode == "rgb_array":
|
||||
return self.state.copy()
|
||||
print(self.state)
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
|
@ -73,7 +73,7 @@ class SimpleMultiObsEnv(gym.Env):
|
|||
self.init_possible_transitions()
|
||||
|
||||
self.num_col = num_col
|
||||
self.state_mapping = []
|
||||
self.state_mapping: List[Dict[str, np.ndarray]] = []
|
||||
self.init_state_mapping(num_col, num_row)
|
||||
|
||||
self.max_state = len(self.state_mapping) - 1
|
||||
|
|
@ -121,20 +121,18 @@ class SimpleMultiObsEnv(gym.Env):
|
|||
self.right_possible = [0, 1, 2, 12, 13, 14]
|
||||
self.up_possible = [4, 8, 12, 7, 11, 15]
|
||||
|
||||
def step(self, action: Union[float, np.ndarray]) -> GymStepReturn:
|
||||
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
|
||||
"""
|
||||
Run one timestep of the environment's dynamics. When end of
|
||||
episode is reached, you are responsible for calling `reset()`
|
||||
to reset this environment's state.
|
||||
Accepts an action and returns a tuple (observation, reward, done, info).
|
||||
Accepts an action and returns a tuple (observation, reward, terminated, truncated, info).
|
||||
|
||||
:param action:
|
||||
:return: tuple (observation, reward, done, info).
|
||||
:return: tuple (observation, reward, terminated, truncated, info).
|
||||
"""
|
||||
if not self.discrete_actions:
|
||||
action = np.argmax(action)
|
||||
else:
|
||||
action = int(action)
|
||||
action = np.argmax(action) # type: ignore[assignment]
|
||||
|
||||
self.count += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import sys
|
|||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from io import TextIOWrapper
|
||||
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -113,7 +114,7 @@ class KVWriter:
|
|||
Key Value writer
|
||||
"""
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
|
||||
"""
|
||||
Write a dictionary to file
|
||||
|
||||
|
|
@ -135,7 +136,7 @@ class SeqWriter:
|
|||
sequence writer
|
||||
"""
|
||||
|
||||
def write_sequence(self, sequence: List) -> None:
|
||||
def write_sequence(self, sequence: List[str]) -> None:
|
||||
"""
|
||||
write_sequence an array to file
|
||||
|
||||
|
|
@ -163,15 +164,16 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
if isinstance(filename_or_file, str):
|
||||
self.file = open(filename_or_file, "w")
|
||||
self.own_file = True
|
||||
else:
|
||||
assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
|
||||
elif isinstance(filename_or_file, TextIOWrapper): # equivalent to `isinstance(..., TextIO)` (not supported)
|
||||
self.file = filename_or_file
|
||||
self.own_file = False
|
||||
else:
|
||||
raise ValueError(f"Expected file or str, got {filename_or_file}")
|
||||
|
||||
def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
|
||||
# Create strings for printing
|
||||
key2str = {}
|
||||
tag = None
|
||||
tag = ""
|
||||
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
||||
if excluded is not None and ("stdout" in excluded or "log" in excluded):
|
||||
continue
|
||||
|
|
@ -197,9 +199,9 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
if key.find("/") > 0: # Find tag and add it to the dict
|
||||
tag = key[: key.find("/") + 1]
|
||||
key2str[(tag, self._truncate(tag))] = ""
|
||||
# Remove tag from key
|
||||
if tag is not None and tag in key:
|
||||
key = str(" " + key[len(tag) :])
|
||||
# Remove tag from key and indent the key
|
||||
if len(tag) > 0 and tag in key:
|
||||
key = f"{'':3}{key[len(tag) :]}"
|
||||
|
||||
truncated_key = self._truncate(key)
|
||||
if (tag, truncated_key) in key2str:
|
||||
|
|
@ -240,8 +242,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
string = string[: self.max_length - 3] + "..."
|
||||
return string
|
||||
|
||||
def write_sequence(self, sequence: List) -> None:
|
||||
sequence = list(sequence)
|
||||
def write_sequence(self, sequence: List[str]) -> None:
|
||||
for i, elem in enumerate(sequence):
|
||||
self.file.write(elem)
|
||||
if i < len(sequence) - 1: # add space unless this is the last one
|
||||
|
|
@ -257,9 +258,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
self.file.close()
|
||||
|
||||
|
||||
def filter_excluded_keys(
|
||||
key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
|
||||
) -> Dict[str, Any]:
|
||||
def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Filters the keys specified by ``key_exclude`` for the specified format
|
||||
|
||||
|
|
@ -285,7 +284,7 @@ class JSONOutputFormat(KVWriter):
|
|||
def __init__(self, filename: str):
|
||||
self.file = open(filename, "w")
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
|
||||
def cast_to_json_serializable(value: Any):
|
||||
if isinstance(value, Video):
|
||||
raise FormatUnsupportedError(["json"], "video")
|
||||
|
|
@ -332,7 +331,7 @@ class CSVOutputFormat(KVWriter):
|
|||
self.separator = ","
|
||||
self.quotechar = '"'
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
|
||||
# Add our current row to the history
|
||||
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
|
||||
extra_keys = key_values.keys() - self.keys
|
||||
|
|
@ -394,10 +393,12 @@ class TensorBoardOutputFormat(KVWriter):
|
|||
"""
|
||||
|
||||
def __init__(self, folder: str):
|
||||
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
|
||||
assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
|
||||
self.writer = SummaryWriter(log_dir=folder)
|
||||
self._is_closed = False
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
|
||||
assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
|
||||
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
||||
if excluded is not None and "tensorboard" in excluded:
|
||||
continue
|
||||
|
|
@ -437,7 +438,7 @@ class TensorBoardOutputFormat(KVWriter):
|
|||
"""
|
||||
if self.writer:
|
||||
self.writer.close()
|
||||
self.writer = None
|
||||
self._is_closed = True
|
||||
|
||||
|
||||
def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
|
||||
|
|
@ -478,13 +479,24 @@ class Logger:
|
|||
"""
|
||||
|
||||
def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
|
||||
self.name_to_value = defaultdict(float) # values this iteration
|
||||
self.name_to_count = defaultdict(int)
|
||||
self.name_to_excluded = defaultdict(str)
|
||||
self.name_to_value: Dict[str, float] = defaultdict(float) # values this iteration
|
||||
self.name_to_count: Dict[str, int] = defaultdict(int)
|
||||
self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
|
||||
self.level = INFO
|
||||
self.dir = folder
|
||||
self.output_formats = output_formats
|
||||
|
||||
@staticmethod
|
||||
def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
|
||||
"""
|
||||
Helper function to convert str to tuple of str.
|
||||
"""
|
||||
if string_or_tuple is None:
|
||||
return ("",)
|
||||
if isinstance(string_or_tuple, tuple):
|
||||
return string_or_tuple
|
||||
return (string_or_tuple,)
|
||||
|
||||
def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
||||
"""
|
||||
Log a value of some diagnostic
|
||||
|
|
@ -496,9 +508,9 @@ class Logger:
|
|||
:param exclude: outputs to be excluded
|
||||
"""
|
||||
self.name_to_value[key] = value
|
||||
self.name_to_excluded[key] = exclude
|
||||
self.name_to_excluded[key] = self.to_tuple(exclude)
|
||||
|
||||
def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
||||
def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
||||
"""
|
||||
The same as record(), but if called many times, values averaged.
|
||||
|
||||
|
|
@ -507,12 +519,11 @@ class Logger:
|
|||
:param exclude: outputs to be excluded
|
||||
"""
|
||||
if value is None:
|
||||
self.name_to_value[key] = None
|
||||
return
|
||||
old_val, count = self.name_to_value[key], self.name_to_count[key]
|
||||
self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
|
||||
self.name_to_count[key] = count + 1
|
||||
self.name_to_excluded[key] = exclude
|
||||
self.name_to_excluded[key] = self.to_tuple(exclude)
|
||||
|
||||
def dump(self, step: int = 0) -> None:
|
||||
"""
|
||||
|
|
@ -592,7 +603,7 @@ class Logger:
|
|||
"""
|
||||
self.level = level
|
||||
|
||||
def get_dir(self) -> str:
|
||||
def get_dir(self) -> Optional[str]:
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call start)
|
||||
|
|
@ -610,7 +621,7 @@ class Logger:
|
|||
|
||||
# Misc
|
||||
# ----------------------------------------
|
||||
def _do_log(self, args) -> None:
|
||||
def _do_log(self, args: Tuple[Any, ...]) -> None:
|
||||
"""
|
||||
log to the requested format outputs
|
||||
|
||||
|
|
@ -618,7 +629,7 @@ class Logger:
|
|||
"""
|
||||
for _format in self.output_formats:
|
||||
if isinstance(_format, SeqWriter):
|
||||
_format.write_sequence(map(str, args))
|
||||
_format.write_sequence(list(map(str, args)))
|
||||
|
||||
|
||||
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class RMSpropTFLike(Optimizer):
|
|||
group.setdefault("centered", False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure: Optional[Callable[[], None]] = None) -> Optional[torch.Tensor]:
|
||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
||||
"""Performs a single optimization step.
|
||||
|
||||
:param closure: A closure that reevaluates the model
|
||||
|
|
|
|||
|
|
@ -19,17 +19,17 @@ VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
|
|||
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
|
||||
|
||||
|
||||
def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
|
||||
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
|
||||
"""
|
||||
Tile N images into one big PxQ image
|
||||
(P,Q) are chosen to be as close as possible, and if N
|
||||
is square, then P=Q.
|
||||
|
||||
:param img_nhwc: list or array of images, ndim=4 once turned into array. img nhwc
|
||||
:param images_nhwc: list or array of images, ndim=4 once turned into array.
|
||||
n = batch index, h = height, w = width, c = channel
|
||||
:return: img_HWc, ndim=3
|
||||
"""
|
||||
img_nhwc = np.asarray(img_nhwc)
|
||||
img_nhwc = np.asarray(images_nhwc)
|
||||
n_images, height, width, n_channels = img_nhwc.shape
|
||||
# new_height was named H before
|
||||
new_height = int(np.ceil(np.sqrt(n_images)))
|
||||
|
|
@ -67,7 +67,8 @@ class VecEnv(ABC):
|
|||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.render_mode = render_mode
|
||||
self.reset_infos = [{} for _ in range(num_envs)] # store info returned by the reset method
|
||||
# store info returned by the reset method
|
||||
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> VecEnvObs:
|
||||
|
|
@ -192,7 +193,7 @@ class VecEnv(ABC):
|
|||
"but the render mode defined when initializing the environment must be "
|
||||
f"'human' or 'rgb_array', not '{self.render_mode}'."
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
elif mode and self.render_mode != mode:
|
||||
warnings.warn(
|
||||
|
|
@ -200,26 +201,26 @@ class VecEnv(ABC):
|
|||
We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode})
|
||||
has to be the same as the environment render mode ({self.render_mode}) which is not the case."""
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
mode = mode or self.render_mode
|
||||
|
||||
if mode is None:
|
||||
warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.")
|
||||
return
|
||||
return None
|
||||
|
||||
# mode == self.render_mode == "human"
|
||||
# In that case, we try to call `self.env.render()` but it might
|
||||
# crash for subprocesses
|
||||
if self.render_mode == "human":
|
||||
self.env_method("render")
|
||||
return
|
||||
return None
|
||||
|
||||
if mode == "rgb_array" or mode == "human":
|
||||
# call the render method of the environments
|
||||
images = self.get_images()
|
||||
# Create a big image by tiling images from subprocesses
|
||||
bigimg = tile_images(images)
|
||||
bigimg = tile_images(images) # type: ignore[arg-type]
|
||||
|
||||
if mode == "human":
|
||||
# Display it using OpenCV
|
||||
|
|
@ -236,9 +237,10 @@ class VecEnv(ABC):
|
|||
# crash for subprocesses
|
||||
# and we don't return the values
|
||||
self.env_method("render")
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
|
||||
"""
|
||||
Sets the random seeds for all environments, based on a given seed.
|
||||
Each individual environment will still get its own seed, by incrementing the given seed.
|
||||
|
|
@ -319,7 +321,7 @@ class VecEnvWrapper(VecEnv):
|
|||
def step_wait(self) -> VecEnvStepReturn:
|
||||
pass
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
|
||||
return self.venv.seed(seed)
|
||||
|
||||
def close(self) -> None:
|
||||
|
|
@ -394,7 +396,7 @@ class VecEnvWrapper(VecEnv):
|
|||
all_attributes = self._get_all_attributes()
|
||||
if name in all_attributes and already_found:
|
||||
# this venv's attribute is being hidden because of a higher venv.
|
||||
shadowed_wrapper_class = f"{type(self).__module__}.{type(self).__name__}"
|
||||
shadowed_wrapper_class: Optional[str] = f"{type(self).__module__}.{type(self).__name__}"
|
||||
elif name in all_attributes and not already_found:
|
||||
# we have found the first reference to the attribute. Now check for duplicates.
|
||||
shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, Optional, Sequence, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
|
@ -24,6 +24,8 @@ class DummyVecEnv(VecEnv):
|
|||
:raises ValueError: If the same environment instance is passed as the output of two or more different env_fn.
|
||||
"""
|
||||
|
||||
actions: np.ndarray
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
|
||||
self.envs = [_patch_env(fn()) for fn in env_fns]
|
||||
if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
|
||||
|
|
@ -44,8 +46,7 @@ class DummyVecEnv(VecEnv):
|
|||
self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
|
||||
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
|
||||
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
||||
self.buf_infos = [{} for _ in range(self.num_envs)]
|
||||
self.actions = None
|
||||
self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
|
||||
self.metadata = env.metadata
|
||||
|
||||
def step_async(self, actions: np.ndarray) -> None:
|
||||
|
|
@ -70,7 +71,7 @@ class DummyVecEnv(VecEnv):
|
|||
self._save_obs(env_idx, obs)
|
||||
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
|
||||
# Avoid circular import
|
||||
from stable_baselines3.common.utils import compat_gym_seed
|
||||
|
||||
|
|
@ -78,7 +79,7 @@ class DummyVecEnv(VecEnv):
|
|||
seed = np.random.randint(0, 2**32 - 1)
|
||||
seeds = []
|
||||
for idx, env in enumerate(self.envs):
|
||||
seeds.append(compat_gym_seed(env, seed=seed + idx))
|
||||
seeds.append(compat_gym_seed(env, seed=seed + idx)) # type: ignore[func-returns-value]
|
||||
return seeds
|
||||
|
||||
def reset(self) -> VecEnvObs:
|
||||
|
|
@ -97,7 +98,7 @@ class DummyVecEnv(VecEnv):
|
|||
f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
|
||||
)
|
||||
return [None for _ in self.envs]
|
||||
return [env.render() for env in self.envs]
|
||||
return [env.render() for env in self.envs] # type: ignore[misc]
|
||||
|
||||
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
|
||||
"""
|
||||
|
|
@ -113,7 +114,7 @@ class DummyVecEnv(VecEnv):
|
|||
if key is None:
|
||||
self.buf_obs[key][env_idx] = obs
|
||||
else:
|
||||
self.buf_obs[key][env_idx] = obs[key]
|
||||
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]
|
||||
|
||||
def _obs_from_buf(self) -> VecEnvObs:
|
||||
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ class SubprocVecEnv(VecEnv):
|
|||
obs, rews, dones, infos, self.reset_infos = zip(*results)
|
||||
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
|
||||
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
|
||||
if seed is None:
|
||||
seed = np.random.randint(0, 2**32 - 1)
|
||||
for idx, remote in enumerate(self.remotes):
|
||||
|
|
|
|||
|
|
@ -62,10 +62,10 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[
|
|||
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
|
||||
subspaces = obs_space.spaces
|
||||
elif isinstance(obs_space, spaces.Tuple):
|
||||
subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
|
||||
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
|
||||
else:
|
||||
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
|
||||
subspaces = {None: obs_space}
|
||||
subspaces = {None: obs_space} # type: ignore[assignment]
|
||||
keys = []
|
||||
shapes = {}
|
||||
dtypes = {}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
|
||||
|
||||
|
|
@ -13,14 +14,19 @@ class VecExtractDictObs(VecEnvWrapper):
|
|||
|
||||
def __init__(self, venv: VecEnv, key: str):
|
||||
self.key = key
|
||||
assert isinstance(
|
||||
venv.observation_space, spaces.Dict
|
||||
), f"VecExtractDictObs can only be used with Dict obs space, not {venv.observation_space}"
|
||||
super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])
|
||||
|
||||
def reset(self) -> np.ndarray:
|
||||
obs = self.venv.reset()
|
||||
assert isinstance(obs, dict)
|
||||
return obs[self.key]
|
||||
|
||||
def step_wait(self) -> VecEnvStepReturn:
|
||||
obs, reward, done, infos = self.venv.step_wait()
|
||||
assert isinstance(obs, dict)
|
||||
for info in infos:
|
||||
if "terminal_observation" in info:
|
||||
info["terminal_observation"] = info["terminal_observation"][self.key]
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class VecFrameStack(VecEnvWrapper):
|
|||
self,
|
||||
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
|
||||
observations, rewards, dones, infos = self.venv.step_wait()
|
||||
observations, infos = self.stacked_obs.update(observations, dones, infos)
|
||||
observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
|
||||
return observations, rewards, dones, infos
|
||||
|
||||
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
||||
|
|
@ -39,5 +39,5 @@ class VecFrameStack(VecEnvWrapper):
|
|||
Reset all environments
|
||||
"""
|
||||
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
|
||||
observation = self.stacked_obs.reset(observation)
|
||||
observation = self.stacked_obs.reset(observation) # type: ignore[arg-type]
|
||||
return observation
|
||||
|
|
|
|||
|
|
@ -49,8 +49,6 @@ class VecMonitor(VecEnvWrapper):
|
|||
)
|
||||
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
self.episode_returns = None
|
||||
self.episode_lengths = None
|
||||
self.episode_count = 0
|
||||
self.t_start = time.time()
|
||||
|
||||
|
|
@ -58,13 +56,15 @@ class VecMonitor(VecEnvWrapper):
|
|||
if hasattr(venv, "spec") and venv.spec is not None:
|
||||
env_id = venv.spec.id
|
||||
|
||||
self.results_writer: Optional[ResultsWriter] = None
|
||||
if filename:
|
||||
self.results_writer = ResultsWriter(
|
||||
filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords
|
||||
filename, header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=info_keywords
|
||||
)
|
||||
else:
|
||||
self.results_writer = None
|
||||
|
||||
self.info_keywords = info_keywords
|
||||
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
|
||||
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
|
||||
|
||||
def reset(self) -> VecEnvObs:
|
||||
obs = self.venv.reset()
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.0.0a6
|
||||
2.0.0a7
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
|
|||
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
|
||||
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
|
||||
from stable_baselines3.common.utils import (
|
||||
check_shape_equal,
|
||||
get_parameters_by_name,
|
||||
|
|
@ -349,7 +349,7 @@ def test_vec_noise():
|
|||
num_actions = 10
|
||||
mu = np.zeros(num_actions)
|
||||
sigma = np.ones(num_actions) * 0.4
|
||||
base: ActionNoise = OrnsteinUhlenbeckActionNoise(mu, sigma)
|
||||
base = OrnsteinUhlenbeckActionNoise(mu, sigma)
|
||||
with pytest.raises(ValueError):
|
||||
vec = VectorizedActionNoise(base, -1)
|
||||
with pytest.raises(ValueError):
|
||||
|
|
|
|||
Loading…
Reference in a new issue