mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Refactor observation stacking (#1238)
* refactor stacking obs * Improve docstring * remove all StackedDictObservations * Update tests and make stacked obs clearer * Fix type check * fix stacked_observation_space * undo init change, deprecate StackedDictObservations * deprecate stack_observation_space * type hints * ignore pytype errors * undo vecenv doc change * Deprecation warning in StackedDictObs docstring * Fix vec_env.rst * Fix __all__ sorting * fix pytype ignore statement * Update docstring * stack * Remove n_stack * Update changelog * Simplify code * Rename test file * Re-use variable for shift * Fix doc build * Remove pytype comment * Disable pytype error --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
411ff697dd
commit
2e4a45020e
8 changed files with 459 additions and 234 deletions
|
|
@ -122,12 +122,6 @@ StackedObservations
|
|||
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations
|
||||
:members:
|
||||
|
||||
StackedDictObservations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedDictObservations
|
||||
:members:
|
||||
|
||||
VecNormalize
|
||||
~~~~~~~~~~~~
|
||||
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.8.0a3 (WIP)
|
||||
Release 1.8.0a4 (WIP)
|
||||
--------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
|
||||
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -36,6 +37,7 @@ Others:
|
|||
- Fixed ``tests/test_tensorboard.py`` type hint
|
||||
- Fixed ``tests/test_vec_normalize.py`` type hint
|
||||
- Fixed ``stable_baselines3/common/monitor.py`` type hint
|
||||
- Added tests for StackedObservations
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -48,7 +48,6 @@ exclude = (?x)(
|
|||
| stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/base_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/dummy_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/stacked_observations.py$
|
||||
| stable_baselines3/common/vec_env/subproc_vec_env.py$
|
||||
| stable_baselines3/common/vec_env/util.py$
|
||||
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Optional, Type, Union
|
|||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
|
||||
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
|
||||
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
|
||||
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
|
||||
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
|
||||
|
|
@ -78,7 +78,6 @@ __all__ = [
|
|||
"VecEnv",
|
||||
"VecEnvWrapper",
|
||||
"DummyVecEnv",
|
||||
"StackedDictObservations",
|
||||
"StackedObservations",
|
||||
"SubprocVecEnv",
|
||||
"VecCheckNan",
|
||||
|
|
|
|||
|
|
@ -1,61 +1,80 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
|
||||
TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
|
||||
|
||||
class StackedObservations:
|
||||
|
||||
# Disable errors for pytype which doesn't play well with Generic[TypeVar]
|
||||
# mypy check passes though
|
||||
# pytype: disable=attribute-error
|
||||
class StackedObservations(Generic[TObs]):
|
||||
"""
|
||||
Frame stacking wrapper for data.
|
||||
|
||||
Dimension to stack over is either first (channels-first) or
|
||||
last (channels-last), which is detected automatically using
|
||||
``common.preprocessing.is_image_space_channels_first`` if
|
||||
observation is an image space.
|
||||
Dimension to stack over is either first (channels-first) or last (channels-last), which is detected automatically using
|
||||
``common.preprocessing.is_image_space_channels_first`` if observation is an image space.
|
||||
|
||||
:param num_envs: number of environments
|
||||
:param num_envs: Number of environments
|
||||
:param n_stack: Number of frames to stack
|
||||
:param observation_space: Environment observation space.
|
||||
:param observation_space: Environment observation space
|
||||
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
|
||||
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
|
||||
If None, automatically detect channel to stack over in case of image observation or default to "last".
|
||||
For Dict space, channels_order can also be a dictionary.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_envs: int,
|
||||
n_stack: int,
|
||||
observation_space: spaces.Space,
|
||||
channels_order: Optional[str] = None,
|
||||
):
|
||||
observation_space: Union[spaces.Box, spaces.Dict], # Replace by Space[TObs] in gym>=0.26
|
||||
channels_order: Optional[Union[str, Mapping[str, Optional[str]]]] = None,
|
||||
) -> None:
|
||||
self.n_stack = n_stack
|
||||
(
|
||||
self.channels_first,
|
||||
self.stack_dimension,
|
||||
self.stackedobs,
|
||||
self.repeat_axis,
|
||||
) = self.compute_stacking(num_envs, n_stack, observation_space, channels_order)
|
||||
super().__init__()
|
||||
self.observation_space = observation_space
|
||||
if isinstance(observation_space, spaces.Dict):
|
||||
if not isinstance(channels_order, Mapping):
|
||||
channels_order = {key: channels_order for key in observation_space.spaces.keys()}
|
||||
self.sub_stacked_observations = {
|
||||
key: StackedObservations(num_envs, n_stack, subspace, channels_order[key])
|
||||
for key, subspace in observation_space.spaces.items()
|
||||
}
|
||||
self.stacked_observation_space = spaces.Dict(
|
||||
{key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()}
|
||||
) # type: spaces.Dict # make mypy happy
|
||||
elif isinstance(observation_space, spaces.Box):
|
||||
if isinstance(channels_order, Mapping):
|
||||
raise TypeError("When the observation space is Box, channels_order can't be a dict.")
|
||||
|
||||
self.channels_first, self.stack_dimension, self.stacked_shape, self.repeat_axis = self.compute_stacking(
|
||||
n_stack, observation_space, channels_order
|
||||
)
|
||||
low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis)
|
||||
high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis)
|
||||
self.stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype)
|
||||
self.stacked_obs = np.zeros((num_envs,) + self.stacked_shape, dtype=observation_space.dtype)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compute_stacking(
|
||||
num_envs: int,
|
||||
n_stack: int,
|
||||
observation_space: spaces.Box,
|
||||
channels_order: Optional[str] = None,
|
||||
) -> Tuple[bool, int, np.ndarray, int]:
|
||||
n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None
|
||||
) -> Tuple[bool, int, Tuple[int, ...], int]:
|
||||
"""
|
||||
Calculates the parameters in order to stack observations
|
||||
|
||||
:param num_envs: Number of environments in the stack
|
||||
:param n_stack: The number of observations to stack
|
||||
:param observation_space: The observation space
|
||||
:param channels_order: The order of the channels
|
||||
:return: tuple of channels_first, stack_dimension, stackedobs, repeat_axis
|
||||
:param n_stack: Number of observations to stack
|
||||
:param observation_space: Observation space
|
||||
:param channels_order: Order of the channels
|
||||
:return: Tuple of channels_first, stack_dimension, stacked_shape, repeat_axis
|
||||
"""
|
||||
channels_first = False
|
||||
|
||||
if channels_order is None:
|
||||
# Detect channel location automatically for images
|
||||
if is_image_space(observation_space):
|
||||
|
|
@ -74,192 +93,113 @@ class StackedObservations:
|
|||
# This includes the vec-env dimension (first)
|
||||
stack_dimension = 1 if channels_first else -1
|
||||
repeat_axis = 0 if channels_first else -1
|
||||
low = np.repeat(observation_space.low, n_stack, axis=repeat_axis)
|
||||
stackedobs = np.zeros((num_envs,) + low.shape, low.dtype)
|
||||
return channels_first, stack_dimension, stackedobs, repeat_axis
|
||||
stacked_shape = list(observation_space.shape)
|
||||
stacked_shape[repeat_axis] *= n_stack
|
||||
return channels_first, stack_dimension, tuple(stacked_shape), repeat_axis
|
||||
|
||||
def stack_observation_space(self, observation_space: spaces.Box) -> spaces.Box:
|
||||
def stack_observation_space(self, observation_space: Union[spaces.Box, spaces.Dict]) -> Union[spaces.Box, spaces.Dict]:
|
||||
"""
|
||||
Given an observation space, returns a new observation space with stacked observations
|
||||
This function is deprecated.
|
||||
|
||||
As an alternative, use
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
low = np.repeat(observation_space.low, stacked_observation.n_stack, axis=stacked_observation.repeat_axis)
|
||||
high = np.repeat(observation_space.high, stacked_observation.n_stack, axis=stacked_observation.repeat_axis)
|
||||
stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype)
|
||||
|
||||
:return: New observation space with stacked dimensions
|
||||
"""
|
||||
warnings.warn(
|
||||
"stack_observation_space is deprecated and will be removed in the next SB3 release. "
|
||||
"Please refer to the docstring for a workaround.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if isinstance(observation_space, spaces.Dict):
|
||||
return spaces.Dict(
|
||||
{
|
||||
key: sub_stacked_observation.stack_observation_space(sub_stacked_observation.observation_space)
|
||||
for key, sub_stacked_observation in self.sub_stacked_observations.items()
|
||||
}
|
||||
)
|
||||
low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis)
|
||||
high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis)
|
||||
return spaces.Box(low=low, high=high, dtype=observation_space.dtype)
|
||||
|
||||
def reset(self, observation: np.ndarray) -> np.ndarray:
|
||||
def reset(self, observation: TObs) -> TObs:
|
||||
"""
|
||||
Resets the stackedobs, adds the reset observation to the stack, and returns the stack
|
||||
Reset the stacked_obs, add the reset observation to the stack, and return the stack.
|
||||
|
||||
:param observation: Reset observation
|
||||
:return: The stacked reset observation
|
||||
"""
|
||||
self.stackedobs[...] = 0
|
||||
if isinstance(observation, dict):
|
||||
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
|
||||
|
||||
self.stacked_obs[...] = 0
|
||||
if self.channels_first:
|
||||
self.stackedobs[:, -observation.shape[self.stack_dimension] :, ...] = observation
|
||||
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
|
||||
else:
|
||||
self.stackedobs[..., -observation.shape[self.stack_dimension] :] = observation
|
||||
return self.stackedobs
|
||||
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
|
||||
return self.stacked_obs
|
||||
|
||||
def update(
|
||||
self,
|
||||
observations: np.ndarray,
|
||||
observations: TObs,
|
||||
dones: np.ndarray,
|
||||
infos: List[Dict[str, Any]],
|
||||
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
|
||||
) -> Tuple[TObs, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Adds the observations to the stack and uses the dones to update the infos.
|
||||
Add the observations to the stack and use the dones to update the infos.
|
||||
|
||||
:param observations: numpy array of observations
|
||||
:param dones: numpy array of done info
|
||||
:param infos: numpy array of info dicts
|
||||
:return: tuple of the stacked observations and the updated infos
|
||||
:param observations: Observations
|
||||
:param dones: Dones
|
||||
:param infos: Infos
|
||||
:return: Tuple of the stacked observations and the updated infos
|
||||
"""
|
||||
stack_ax_size = observations.shape[self.stack_dimension]
|
||||
self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension)
|
||||
for i, done in enumerate(dones):
|
||||
if isinstance(observations, dict):
|
||||
# From [{}, {terminal_obs: {key1: ..., key2: ...}}]
|
||||
# to {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
|
||||
sub_infos = {
|
||||
key: [
|
||||
{"terminal_observation": info["terminal_observation"][key]} if "terminal_observation" in info else {}
|
||||
for info in infos
|
||||
]
|
||||
for key in observations.keys()
|
||||
}
|
||||
|
||||
stacked_obs = {}
|
||||
stacked_infos = {}
|
||||
for key, obs in observations.items():
|
||||
stacked_obs[key], stacked_infos[key] = self.sub_stacked_observations[key].update(obs, dones, sub_infos[key])
|
||||
|
||||
# From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
|
||||
# to [{}, {terminal_obs: {key1: ..., key2: ...}}]
|
||||
for key in stacked_infos.keys():
|
||||
for env_idx in range(len(infos)):
|
||||
if "terminal_observation" in infos[env_idx]:
|
||||
infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
|
||||
return stacked_obs, infos
|
||||
|
||||
shift = -observations.shape[self.stack_dimension]
|
||||
self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
|
||||
for env_idx, done in enumerate(dones):
|
||||
if done:
|
||||
if "terminal_observation" in infos[i]:
|
||||
old_terminal = infos[i]["terminal_observation"]
|
||||
if "terminal_observation" in infos[env_idx]:
|
||||
old_terminal = infos[env_idx]["terminal_observation"]
|
||||
if self.channels_first:
|
||||
new_terminal = np.concatenate(
|
||||
(self.stackedobs[i, :-stack_ax_size, ...], old_terminal),
|
||||
axis=0, # self.stack_dimension - 1, as there is not batch dim
|
||||
)
|
||||
previous_stack = self.stacked_obs[env_idx, :shift, ...]
|
||||
else:
|
||||
new_terminal = np.concatenate(
|
||||
(self.stackedobs[i, ..., :-stack_ax_size], old_terminal),
|
||||
axis=self.stack_dimension,
|
||||
)
|
||||
infos[i]["terminal_observation"] = new_terminal
|
||||
previous_stack = self.stacked_obs[env_idx, ..., :shift]
|
||||
|
||||
new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
|
||||
infos[env_idx]["terminal_observation"] = new_terminal
|
||||
else:
|
||||
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
|
||||
self.stackedobs[i] = 0
|
||||
self.stacked_obs[env_idx] = 0
|
||||
if self.channels_first:
|
||||
self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations
|
||||
self.stacked_obs[:, shift:, ...] = observations
|
||||
else:
|
||||
self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations
|
||||
return self.stackedobs, infos
|
||||
|
||||
|
||||
class StackedDictObservations(StackedObservations):
|
||||
"""
|
||||
Frame stacking wrapper for dictionary data.
|
||||
|
||||
Dimension to stack over is either first (channels-first) or
|
||||
last (channels-last), which is detected automatically using
|
||||
``common.preprocessing.is_image_space_channels_first`` if
|
||||
observation is an image space.
|
||||
|
||||
:param num_envs: number of environments
|
||||
:param n_stack: Number of frames to stack
|
||||
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
|
||||
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_envs: int,
|
||||
n_stack: int,
|
||||
observation_space: spaces.Dict,
|
||||
channels_order: Optional[Union[str, Dict[str, str]]] = None,
|
||||
):
|
||||
self.n_stack = n_stack
|
||||
self.channels_first = {}
|
||||
self.stack_dimension = {}
|
||||
self.stackedobs = {}
|
||||
self.repeat_axis = {}
|
||||
|
||||
for key, subspace in observation_space.spaces.items():
|
||||
assert isinstance(subspace, spaces.Box), "StackedDictObservations only works with nested gym.spaces.Box"
|
||||
if isinstance(channels_order, str) or channels_order is None:
|
||||
subspace_channel_order = channels_order
|
||||
else:
|
||||
subspace_channel_order = channels_order[key]
|
||||
(
|
||||
self.channels_first[key],
|
||||
self.stack_dimension[key],
|
||||
self.stackedobs[key],
|
||||
self.repeat_axis[key],
|
||||
) = self.compute_stacking(num_envs, n_stack, subspace, subspace_channel_order)
|
||||
|
||||
def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict:
|
||||
"""
|
||||
Returns the stacked version of a Dict observation space
|
||||
|
||||
:param observation_space: Dict observation space to stack
|
||||
:return: stacked observation space
|
||||
"""
|
||||
spaces_dict = {}
|
||||
for key, subspace in observation_space.spaces.items():
|
||||
low = np.repeat(subspace.low, self.n_stack, axis=self.repeat_axis[key])
|
||||
high = np.repeat(subspace.high, self.n_stack, axis=self.repeat_axis[key])
|
||||
spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype)
|
||||
return spaces.Dict(spaces=spaces_dict)
|
||||
|
||||
def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch
|
||||
"""
|
||||
Resets the stacked observations, adds the reset observation to the stack, and returns the stack
|
||||
|
||||
:param observation: Reset observation
|
||||
:return: Stacked reset observations
|
||||
"""
|
||||
for key, obs in observation.items():
|
||||
self.stackedobs[key][...] = 0
|
||||
if self.channels_first[key]:
|
||||
self.stackedobs[key][:, -obs.shape[self.stack_dimension[key]] :, ...] = obs
|
||||
else:
|
||||
self.stackedobs[key][..., -obs.shape[self.stack_dimension[key]] :] = obs
|
||||
return self.stackedobs
|
||||
|
||||
def update(
|
||||
self,
|
||||
observations: Dict[str, np.ndarray],
|
||||
dones: np.ndarray,
|
||||
infos: List[Dict[str, Any]],
|
||||
) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch
|
||||
"""
|
||||
Adds the observations to the stack and uses the dones to update the infos.
|
||||
|
||||
:param observations: Dict of numpy arrays of observations
|
||||
:param dones: numpy array of dones
|
||||
:param infos: dict of infos
|
||||
:return: tuple of the stacked observations and the updated infos
|
||||
"""
|
||||
for key in self.stackedobs.keys():
|
||||
stack_ax_size = observations[key].shape[self.stack_dimension[key]]
|
||||
self.stackedobs[key] = np.roll(
|
||||
self.stackedobs[key],
|
||||
shift=-stack_ax_size,
|
||||
axis=self.stack_dimension[key],
|
||||
)
|
||||
|
||||
for i, done in enumerate(dones):
|
||||
if done:
|
||||
if "terminal_observation" in infos[i]:
|
||||
old_terminal = infos[i]["terminal_observation"][key]
|
||||
if self.channels_first[key]:
|
||||
new_terminal = np.vstack(
|
||||
(
|
||||
self.stackedobs[key][i, :-stack_ax_size, ...],
|
||||
old_terminal,
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_terminal = np.concatenate(
|
||||
(
|
||||
self.stackedobs[key][i, ..., :-stack_ax_size],
|
||||
old_terminal,
|
||||
),
|
||||
axis=self.stack_dimension[key],
|
||||
)
|
||||
infos[i]["terminal_observation"][key] = new_terminal
|
||||
else:
|
||||
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
|
||||
self.stackedobs[key][i] = 0
|
||||
if self.channels_first[key]:
|
||||
self.stackedobs[key][:, -stack_ax_size:, ...] = observations[key]
|
||||
else:
|
||||
self.stackedobs[key][..., -stack_ax_size:] = observations[key]
|
||||
return self.stackedobs, infos
|
||||
self.stacked_obs[..., shift:] = observations
|
||||
return self.stacked_obs, infos
|
||||
|
|
|
|||
|
|
@ -1,63 +1,40 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
|
||||
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
|
||||
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
|
||||
|
||||
|
||||
class VecFrameStack(VecEnvWrapper):
|
||||
"""
|
||||
Frame stacking wrapper for vectorized environment. Designed for image observations.
|
||||
|
||||
Uses the StackedObservations class, or StackedDictObservations depending on the observations space
|
||||
|
||||
:param venv: the vectorized environment to wrap
|
||||
:param venv: Vectorized environment to wrap
|
||||
:param n_stack: Number of frames to stack
|
||||
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
|
||||
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
|
||||
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
|
||||
"""
|
||||
|
||||
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
|
||||
self.venv = venv
|
||||
self.n_stack = n_stack
|
||||
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
|
||||
assert isinstance(
|
||||
venv.observation_space, (spaces.Box, spaces.Dict)
|
||||
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"
|
||||
|
||||
wrapped_obs_space = venv.observation_space
|
||||
|
||||
if isinstance(wrapped_obs_space, spaces.Box):
|
||||
assert not isinstance(
|
||||
channels_order, dict
|
||||
), f"Expected None or string for channels_order but received {channels_order}"
|
||||
self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
|
||||
|
||||
elif isinstance(wrapped_obs_space, spaces.Dict):
|
||||
self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
|
||||
|
||||
else:
|
||||
raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")
|
||||
|
||||
observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
|
||||
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
|
||||
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
|
||||
observation_space = self.stacked_obs.stacked_observation_space
|
||||
super().__init__(venv, observation_space=observation_space)
|
||||
|
||||
def step_wait(
|
||||
self,
|
||||
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
|
||||
observations, rewards, dones, infos = self.venv.step_wait()
|
||||
|
||||
observations, infos = self.stackedobs.update(observations, dones, infos)
|
||||
|
||||
observations, infos = self.stacked_obs.update(observations, dones, infos)
|
||||
return observations, rewards, dones, infos
|
||||
|
||||
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
|
||||
"""
|
||||
Reset all environments
|
||||
"""
|
||||
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
|
||||
|
||||
observation = self.stackedobs.reset(observation)
|
||||
observation = self.stacked_obs.reset(observation)
|
||||
return observation
|
||||
|
||||
def close(self) -> None:
|
||||
self.venv.close()
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a3
|
||||
1.8.0a4
|
||||
|
|
|
|||
314
tests/test_vec_stacked_obs.py
Normal file
314
tests/test_vec_stacked_obs.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
|
||||
|
||||
compute_stacking = StackedObservations.compute_stacking
|
||||
NUM_ENVS = 2
|
||||
N_STACK = 4
|
||||
H, W, C = 16, 24, 3
|
||||
|
||||
|
||||
def test_compute_stacking_box():
|
||||
space = spaces.Box(-1, 1, (4,))
|
||||
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
|
||||
assert not channels_first # default is channel last
|
||||
assert stack_dimension == -1
|
||||
assert stacked_shape == (N_STACK * 4,)
|
||||
assert repeat_axis == -1
|
||||
|
||||
|
||||
def test_compute_stacking_multidim_box():
|
||||
space = spaces.Box(-1, 1, (4, 5))
|
||||
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
|
||||
assert not channels_first # default is channel last
|
||||
assert stack_dimension == -1
|
||||
assert stacked_shape == (4, N_STACK * 5)
|
||||
assert repeat_axis == -1
|
||||
|
||||
|
||||
def test_compute_stacking_multidim_box_channel_first():
|
||||
space = spaces.Box(-1, 1, (4, 5))
|
||||
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(
|
||||
N_STACK, observation_space=space, channels_order="first"
|
||||
)
|
||||
assert channels_first # channels_order="first" was requested explicitly
|
||||
assert stack_dimension == 1
|
||||
assert stacked_shape == (N_STACK * 4, 5)
|
||||
assert repeat_axis == 0
|
||||
|
||||
|
||||
def test_compute_stacking_image_channel_first():
|
||||
"""Detect that image is channel first and stack in that dimension."""
|
||||
space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
|
||||
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
|
||||
assert channels_first # channel-first image layout is detected automatically
|
||||
assert stack_dimension == 1
|
||||
assert stacked_shape == (N_STACK * C, H, W)
|
||||
assert repeat_axis == 0
|
||||
|
||||
|
||||
def test_compute_stacking_image_channel_last():
|
||||
"""Detect that image is channel last and stack in that dimension."""
|
||||
space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
|
||||
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
|
||||
assert not channels_first # default is channel last
|
||||
assert stack_dimension == -1
|
||||
assert stacked_shape == (H, W, N_STACK * C)
|
||||
assert repeat_axis == -1
|
||||
|
||||
|
||||
def test_compute_stacking_image_channel_first_stack_last():
|
||||
"""Image is channel first, but channels_order="last" forces stacking on the last dimension."""
|
||||
space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
|
||||
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(
|
||||
N_STACK, observation_space=space, channels_order="last"
|
||||
)
|
||||
assert not channels_first # default is channel last
|
||||
assert stack_dimension == -1
|
||||
assert stacked_shape == (C, H, N_STACK * W)
|
||||
assert repeat_axis == -1
|
||||
|
||||
|
||||
def test_compute_stacking_image_channel_last_stack_first():
|
||||
"""Image is channel last, but channels_order="first" forces stacking on the first dimension."""
|
||||
space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
|
||||
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(
|
||||
N_STACK, observation_space=space, channels_order="first"
|
||||
)
|
||||
assert channels_first # channels_order="first" overrides the detected channel-last layout
|
||||
assert stack_dimension == 1
|
||||
assert stacked_shape == (N_STACK * H, W, C)
|
||||
assert repeat_axis == 0
|
||||
|
||||
|
||||
def test_reset_update_box():
|
||||
space = spaces.Box(-1, 1, (4,))
|
||||
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
|
||||
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
stacked_obs = stacked_observations.reset(observations_1)
|
||||
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
dones = np.zeros((NUM_ENVS,), dtype=bool)
|
||||
infos = [{} for _ in range(NUM_ENVS)]
|
||||
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
|
||||
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
assert np.array_equal(
|
||||
stacked_obs,
|
||||
np.concatenate(
|
||||
(np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_reset_update_multidim_box():
|
||||
space = spaces.Box(-1, 1, (4, 5))
|
||||
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
|
||||
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
stacked_obs = stacked_observations.reset(observations_1)
|
||||
assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
dones = np.zeros((NUM_ENVS,), dtype=bool)
|
||||
infos = [{} for _ in range(NUM_ENVS)]
|
||||
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
|
||||
assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
assert np.array_equal(
|
||||
stacked_obs,
|
||||
np.concatenate(
|
||||
(np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_reset_update_multidim_box_channel_first():
|
||||
space = spaces.Box(-1, 1, (4, 5))
|
||||
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first")
|
||||
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
stacked_obs = stacked_observations.reset(observations_1)
|
||||
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
dones = np.zeros((NUM_ENVS,), dtype=bool)
|
||||
infos = [{} for _ in range(NUM_ENVS)]
|
||||
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
|
||||
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
assert np.array_equal(
|
||||
stacked_obs,
|
||||
np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1),
|
||||
)
|
||||
|
||||
|
||||
def test_reset_update_image_channel_first():
|
||||
space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
|
||||
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
|
||||
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
stacked_obs = stacked_observations.reset(observations_1)
|
||||
assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
|
||||
dones = np.zeros((NUM_ENVS,), dtype=bool)
|
||||
infos = [{} for _ in range(NUM_ENVS)]
|
||||
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
|
||||
assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W)
|
||||
assert stacked_obs.dtype == space.dtype
|
||||
assert np.array_equal(
|
||||
stacked_obs,
|
||||
np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1),
|
||||
)
|
||||
|
||||
|
||||
def test_reset_update_image_channel_last():
    """Stack a ``(H, W, C)`` uint8 image without passing ``channels_order``.

    The assertions expect the frames to be stacked on the last axis (the ``C``
    axis), giving ``(NUM_ENVS, H, W, N_STACK * C)``, and the dtype of the space
    to be preserved by both ``reset`` and ``update``.
    """
    # NOTE: the two zero frames in the expected value presume N_STACK == 4
    # (N_STACK is defined outside this test) -- confirm before changing it.
    image_space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
    stacker = StackedObservations(NUM_ENVS, N_STACK, image_space)
    expected_shape = (NUM_ENVS, H, W, N_STACK * C)

    first_frames = np.stack([image_space.sample() for _ in range(NUM_ENVS)])
    after_reset = stacker.reset(first_frames)
    assert after_reset.shape == expected_shape
    assert after_reset.dtype == image_space.dtype

    second_frames = np.stack([image_space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros(NUM_ENVS, dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    after_update, _ = stacker.update(second_frames, no_dones, empty_infos)
    assert after_update.shape == expected_shape
    assert after_update.dtype == image_space.dtype

    padding = np.zeros_like(first_frames)
    expected = np.concatenate((padding, padding, first_frames, second_frames), axis=-1)
    assert np.array_equal(after_update, expected)
|
||||
|
||||
|
||||
def test_reset_update_image_channel_first_stack_last():
    """Force ``channels_order="last"`` on a ``(C, H, W)`` uint8 image.

    Even though the image layout is channel-first, the assertions expect the
    explicit ``channels_order`` to win: frames are stacked on the last axis,
    giving ``(NUM_ENVS, C, H, N_STACK * W)``.
    """
    # NOTE: the two zero frames in the expected value presume N_STACK == 4
    # (N_STACK is defined outside this test) -- confirm before changing it.
    image_space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
    stacker = StackedObservations(NUM_ENVS, N_STACK, image_space, channels_order="last")
    expected_shape = (NUM_ENVS, C, H, N_STACK * W)

    first_frames = np.stack([image_space.sample() for _ in range(NUM_ENVS)])
    after_reset = stacker.reset(first_frames)
    assert after_reset.shape == expected_shape
    assert after_reset.dtype == image_space.dtype

    second_frames = np.stack([image_space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros(NUM_ENVS, dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    after_update, _ = stacker.update(second_frames, no_dones, empty_infos)
    assert after_update.shape == expected_shape
    assert after_update.dtype == image_space.dtype

    padding = np.zeros_like(first_frames)
    expected = np.concatenate((padding, padding, first_frames, second_frames), axis=-1)
    assert np.array_equal(after_update, expected)
|
||||
|
||||
|
||||
def test_reset_update_image_channel_last_stack_first():
    """Force ``channels_order="first"`` on a ``(H, W, C)`` uint8 image.

    Even though the image layout is channel-last, the assertions expect the
    explicit ``channels_order`` to win: frames are stacked on axis 1 of the
    batched array, giving ``(NUM_ENVS, N_STACK * H, W, C)``.
    """
    # NOTE: the two zero frames in the expected value presume N_STACK == 4
    # (N_STACK is defined outside this test) -- confirm before changing it.
    image_space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
    stacker = StackedObservations(NUM_ENVS, N_STACK, image_space, channels_order="first")
    expected_shape = (NUM_ENVS, N_STACK * H, W, C)

    first_frames = np.stack([image_space.sample() for _ in range(NUM_ENVS)])
    after_reset = stacker.reset(first_frames)
    assert after_reset.shape == expected_shape
    assert after_reset.dtype == image_space.dtype

    second_frames = np.stack([image_space.sample() for _ in range(NUM_ENVS)])
    no_dones = np.zeros(NUM_ENVS, dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    after_update, _ = stacker.update(second_frames, no_dones, empty_infos)
    assert after_update.shape == expected_shape
    assert after_update.dtype == image_space.dtype

    padding = np.zeros_like(first_frames)
    expected = np.concatenate((padding, padding, first_frames, second_frames), axis=1)
    assert np.array_equal(after_update, expected)
|
||||
|
||||
|
||||
def test_reset_update_dict():
    """Stack a ``Dict`` observation with a per-key ``channels_order``.

    ``key1`` (a ``(H, W, C)`` uint8 image) is stacked with ``"first"`` and is
    expected on axis 1 of the batched array; ``key2`` (a ``(4, 5)`` box) is
    stacked with ``"last"`` and is expected on the last axis. Both ``reset``
    and ``update`` must return a dict whose entries keep the sub-space dtype.
    """
    # NOTE: the two zero frames in the expected values presume N_STACK == 4
    # (N_STACK is defined outside this test) -- confirm before changing it.
    dict_space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, C), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))})
    stacker = StackedObservations(NUM_ENVS, N_STACK, dict_space, channels_order={"key1": "first", "key2": "last"})
    expected_shapes = {
        "key1": (NUM_ENVS, N_STACK * H, W, C),
        "key2": (NUM_ENVS, 4, N_STACK * 5),
    }

    def sample_batch():
        # One sample per env for every sub-space, batched on a new leading axis.
        return {name: np.stack([sub.sample() for _ in range(NUM_ENVS)]) for name, sub in dict_space.spaces.items()}

    first_obs = sample_batch()
    stacked = stacker.reset(first_obs)
    assert isinstance(stacked, dict)
    for name, shape in expected_shapes.items():
        assert stacked[name].shape == shape
    for name in expected_shapes:
        assert stacked[name].dtype == dict_space[name].dtype

    second_obs = sample_batch()
    no_dones = np.zeros(NUM_ENVS, dtype=bool)
    empty_infos = [{} for _ in range(NUM_ENVS)]
    stacked, _ = stacker.update(second_obs, no_dones, empty_infos)
    for name, shape in expected_shapes.items():
        assert stacked[name].shape == shape
    for name in expected_shapes:
        assert stacked[name].dtype == dict_space[name].dtype

    # Stacking axis expected for each key in the batched arrays.
    for name, stack_axis in (("key1", 1), ("key2", -1)):
        padding = np.zeros_like(first_obs[name])
        expected = np.concatenate((padding, padding, first_obs[name], second_obs[name]), axis=stack_axis)
        assert np.array_equal(stacked[name], expected)
|
||||
|
||||
|
||||
def test_episode_termination_box():
    """Check that a finished episode clears the stack for that env only.

    After ``reset`` and one ordinary ``update``, the env at index 1 is marked
    done (with a ``terminal_observation`` in its info dict) for the next
    ``update``. The env at index 0 is expected to keep its history, while the
    env at index 1 is expected to hold zeros followed by its newest observation.
    """
    # NOTE: the expected rows presume N_STACK == 4 and NUM_ENVS == 2 (both are
    # defined outside this test) -- confirm before changing either.
    box = spaces.Box(-1, 1, (4,))
    stacker = StackedObservations(NUM_ENVS, N_STACK, box)

    first_obs = np.stack([box.sample() for _ in range(NUM_ENVS)])
    stacker.reset(first_obs)

    second_obs = np.stack([box.sample() for _ in range(NUM_ENVS)])
    dones = np.zeros(NUM_ENVS, dtype=bool)
    infos = [{} for _ in range(NUM_ENVS)]
    stacker.update(second_obs, dones, infos)

    # The env at index 1 ends its episode on the third step.
    infos[1]["terminal_observation"] = box.sample()
    dones[1] = True
    third_obs = np.stack([box.sample() for _ in range(NUM_ENVS)])
    stacked, _ = stacker.update(third_obs, dones, infos)

    blank = np.zeros_like(first_obs[0])
    expected_running = np.concatenate((blank, first_obs[0], second_obs[0], third_obs[0]), axis=-1)
    expected_restarted = np.concatenate((blank, blank, blank, third_obs[1]), axis=-1)
    expected = np.stack((expected_running, expected_restarted))
    assert np.array_equal(expected, stacked)
|
||||
|
||||
|
||||
def test_episode_termination_dict():
    """Check per-env stack clearing on episode end for a ``Dict`` observation.

    Same scenario as the box variant: ``reset``, one ordinary ``update``, then
    an ``update`` where the env at index 1 is done and carries a
    ``terminal_observation``. For every key, the env at index 0 is expected to
    keep its history and the env at index 1 to hold zeros plus its newest
    observation.
    """
    # NOTE: the expected rows presume N_STACK == 4 and NUM_ENVS == 2 (both are
    # defined outside this test) -- confirm before changing either.
    dict_space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, 3), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))})
    stacker = StackedObservations(NUM_ENVS, N_STACK, dict_space, channels_order={"key1": "first", "key2": "last"})

    def sample_batch():
        # One sample per env for every sub-space, batched on a new leading axis.
        return {name: np.stack([sub.sample() for _ in range(NUM_ENVS)]) for name, sub in dict_space.spaces.items()}

    first_obs = sample_batch()
    stacker.reset(first_obs)

    second_obs = sample_batch()
    dones = np.zeros(NUM_ENVS, dtype=bool)
    infos = [{} for _ in range(NUM_ENVS)]
    stacker.update(second_obs, dones, infos)

    # The env at index 1 ends its episode on the third step.
    infos[1]["terminal_observation"] = dict_space.sample()
    dones[1] = True
    third_obs = sample_batch()
    stacked, _ = stacker.update(third_obs, dones, infos)

    # Axes are per-env here (the batch axis is indexed away), so the "first"
    # stacking axis is 0 rather than 1.
    for name, stack_axis in zip(first_obs, (0, -1)):
        blank = np.zeros_like(first_obs[name][0])
        expected_running = np.concatenate(
            (blank, first_obs[name][0], second_obs[name][0], third_obs[name][0]), axis=stack_axis
        )
        expected_restarted = np.concatenate((blank, blank, blank, third_obs[name][1]), axis=stack_axis)
        expected = np.stack((expected_running, expected_restarted))
        assert np.array_equal(expected, stacked[name])
|
||||
Loading…
Reference in a new issue