Refactor observation stacking (#1238)

* refactor stacking obs

* Improve docstring

* remove all StackedDictObservations

* Update tests and make stacked obs clearer

* Fix type check

* fix stacked_observation_space

* undo init change, deprecate StackedDictObservations

* deprecate stack_observation_space

* type hints

* ignore pytype errors

* undo vecenv doc change

* Deprecation warning in StackedDictObs doctstring

* Fix vec_env.rst

* Fix __all__ sorting

* fix pytype ignore statement

* Update docstring

* stack

* Remove n_stack

* Update changelog

* Simplify code

* Rename test file

* Re-use variable for shift

* Fix doc build

* Remove pytype comment

* Disable pytype error

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2023-02-06 22:41:59 +01:00 committed by GitHub
parent 411ff697dd
commit 2e4a45020e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 459 additions and 234 deletions

View file

@ -122,12 +122,6 @@ StackedObservations
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations
:members:
StackedDictObservations
~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedDictObservations
:members:
VecNormalize
~~~~~~~~~~~~

View file

@ -4,13 +4,14 @@ Changelog
==========
Release 1.8.0a3 (WIP)
Release 1.8.0a4 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed)
New Features:
^^^^^^^^^^^^^
@ -36,6 +37,7 @@ Others:
- Fixed ``tests/test_tensorboard.py`` type hint
- Fixed ``tests/test_vec_normalize.py`` type hint
- Fixed ``stable_baselines3/common/monitor.py`` type hint
- Added tests for StackedObservations
Documentation:
^^^^^^^^^^^^^^

View file

@ -48,7 +48,6 @@ exclude = (?x)(
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/base_vec_env.py$
| stable_baselines3/common/vec_env/dummy_vec_env.py$
| stable_baselines3/common/vec_env/stacked_observations.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/util.py$
| stable_baselines3/common/vec_env/vec_extract_dict_obs.py$

View file

@ -4,7 +4,7 @@ from typing import Optional, Type, Union
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
@ -78,7 +78,6 @@ __all__ = [
"VecEnv",
"VecEnvWrapper",
"DummyVecEnv",
"StackedDictObservations",
"StackedObservations",
"SubprocVecEnv",
"VecCheckNan",

View file

@ -1,61 +1,80 @@
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union
import numpy as np
from gym import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
class StackedObservations:
# Disable errors for pytype which doesn't play well with Generic[TypeVar]
# mypy check passes though
# pytype: disable=attribute-error
class StackedObservations(Generic[TObs]):
"""
Frame stacking wrapper for data.
Dimension to stack over is either first (channels-first) or
last (channels-last), which is detected automatically using
``common.preprocessing.is_image_space_channels_first`` if
observation is an image space.
Dimension to stack over is either first (channels-first) or last (channels-last), which is detected automatically using
``common.preprocessing.is_image_space_channels_first`` if observation is an image space.
:param num_envs: number of environments
:param num_envs: Number of environments
:param n_stack: Number of frames to stack
:param observation_space: Environment observation space.
:param observation_space: Environment observation space
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
If None, automatically detect channel to stack over in case of image observation or default to "last".
For Dict space, channels_order can also be a dictionary.
"""
def __init__(
self,
num_envs: int,
n_stack: int,
observation_space: spaces.Space,
channels_order: Optional[str] = None,
):
observation_space: Union[spaces.Box, spaces.Dict], # Replace by Space[TObs] in gym>=0.26
channels_order: Optional[Union[str, Mapping[str, Optional[str]]]] = None,
) -> None:
self.n_stack = n_stack
(
self.channels_first,
self.stack_dimension,
self.stackedobs,
self.repeat_axis,
) = self.compute_stacking(num_envs, n_stack, observation_space, channels_order)
super().__init__()
self.observation_space = observation_space
if isinstance(observation_space, spaces.Dict):
if not isinstance(channels_order, Mapping):
channels_order = {key: channels_order for key in observation_space.spaces.keys()}
self.sub_stacked_observations = {
key: StackedObservations(num_envs, n_stack, subspace, channels_order[key])
for key, subspace in observation_space.spaces.items()
}
self.stacked_observation_space = spaces.Dict(
{key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()}
) # type: spaces.Dict # make mypy happy
elif isinstance(observation_space, spaces.Box):
if isinstance(channels_order, Mapping):
raise TypeError("When the observation space is Box, channels_order can't be a dict.")
self.channels_first, self.stack_dimension, self.stacked_shape, self.repeat_axis = self.compute_stacking(
n_stack, observation_space, channels_order
)
low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis)
high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis)
self.stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype)
self.stacked_obs = np.zeros((num_envs,) + self.stacked_shape, dtype=observation_space.dtype)
else:
raise TypeError(
f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided."
)
@staticmethod
def compute_stacking(
num_envs: int,
n_stack: int,
observation_space: spaces.Box,
channels_order: Optional[str] = None,
) -> Tuple[bool, int, np.ndarray, int]:
n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None
) -> Tuple[bool, int, Tuple[int, ...], int]:
"""
Calculates the parameters in order to stack observations
:param num_envs: Number of environments in the stack
:param n_stack: The number of observations to stack
:param observation_space: The observation space
:param channels_order: The order of the channels
:return: tuple of channels_first, stack_dimension, stackedobs, repeat_axis
:param n_stack: Number of observations to stack
:param observation_space: Observation space
:param channels_order: Order of the channels
:return: Tuple of channels_first, stack_dimension, stackedobs, repeat_axis
"""
channels_first = False
if channels_order is None:
# Detect channel location automatically for images
if is_image_space(observation_space):
@ -74,192 +93,113 @@ class StackedObservations:
# This includes the vec-env dimension (first)
stack_dimension = 1 if channels_first else -1
repeat_axis = 0 if channels_first else -1
low = np.repeat(observation_space.low, n_stack, axis=repeat_axis)
stackedobs = np.zeros((num_envs,) + low.shape, low.dtype)
return channels_first, stack_dimension, stackedobs, repeat_axis
stacked_shape = list(observation_space.shape)
stacked_shape[repeat_axis] *= n_stack
return channels_first, stack_dimension, tuple(stacked_shape), repeat_axis
def stack_observation_space(self, observation_space: spaces.Box) -> spaces.Box:
def stack_observation_space(self, observation_space: Union[spaces.Box, spaces.Dict]) -> Union[spaces.Box, spaces.Dict]:
"""
Given an observation space, returns a new observation space with stacked observations
This function is deprecated.
As an alternative, use
.. code-block:: python
low = np.repeat(observation_space.low, stacked_observation.n_stack, axis=stacked_observation.repeat_axis)
high = np.repeat(observation_space.high, stacked_observation.n_stack, axis=stacked_observation.repeat_axis)
stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype)
:return: New observation space with stacked dimensions
"""
warnings.warn(
"stack_observation_space is deprecated and will be removed in the next SB3 release. "
"Please refer to the docstring for a workaround.",
DeprecationWarning,
)
if isinstance(observation_space, spaces.Dict):
return spaces.Dict(
{
key: sub_stacked_observation.stack_observation_space(sub_stacked_observation.observation_space)
for key, sub_stacked_observation in self.sub_stacked_observations.items()
}
)
low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis)
high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis)
return spaces.Box(low=low, high=high, dtype=observation_space.dtype)
def reset(self, observation: np.ndarray) -> np.ndarray:
def reset(self, observation: TObs) -> TObs:
"""
Resets the stackedobs, adds the reset observation to the stack, and returns the stack
Reset the stacked_obs, add the reset observation to the stack, and return the stack.
:param observation: Reset observation
:return: The stacked reset observation
"""
self.stackedobs[...] = 0
if isinstance(observation, dict):
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
self.stacked_obs[...] = 0
if self.channels_first:
self.stackedobs[:, -observation.shape[self.stack_dimension] :, ...] = observation
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
else:
self.stackedobs[..., -observation.shape[self.stack_dimension] :] = observation
return self.stackedobs
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
return self.stacked_obs
def update(
self,
observations: np.ndarray,
observations: TObs,
dones: np.ndarray,
infos: List[Dict[str, Any]],
) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
) -> Tuple[TObs, List[Dict[str, Any]]]:
"""
Adds the observations to the stack and uses the dones to update the infos.
Add the observations to the stack and use the dones to update the infos.
:param observations: numpy array of observations
:param dones: numpy array of done info
:param infos: numpy array of info dicts
:return: tuple of the stacked observations and the updated infos
:param observations: Observations
:param dones: Dones
:param infos: Infos
:return: Tuple of the stacked observations and the updated infos
"""
stack_ax_size = observations.shape[self.stack_dimension]
self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension)
for i, done in enumerate(dones):
if isinstance(observations, dict):
# From [{}, {terminal_obs: {key1: ..., key2: ...}}]
# to {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
sub_infos = {
key: [
{"terminal_observation": info["terminal_observation"][key]} if "terminal_observation" in info else {}
for info in infos
]
for key in observations.keys()
}
stacked_obs = {}
stacked_infos = {}
for key, obs in observations.items():
stacked_obs[key], stacked_infos[key] = self.sub_stacked_observations[key].update(obs, dones, sub_infos[key])
# From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
# to [{}, {terminal_obs: {key1: ..., key2: ...}}]
for key in stacked_infos.keys():
for env_idx in range(len(infos)):
if "terminal_observation" in infos[env_idx]:
infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
return stacked_obs, infos
shift = -observations.shape[self.stack_dimension]
self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
for env_idx, done in enumerate(dones):
if done:
if "terminal_observation" in infos[i]:
old_terminal = infos[i]["terminal_observation"]
if "terminal_observation" in infos[env_idx]:
old_terminal = infos[env_idx]["terminal_observation"]
if self.channels_first:
new_terminal = np.concatenate(
(self.stackedobs[i, :-stack_ax_size, ...], old_terminal),
axis=0, # self.stack_dimension - 1, as there is not batch dim
)
previous_stack = self.stacked_obs[env_idx, :shift, ...]
else:
new_terminal = np.concatenate(
(self.stackedobs[i, ..., :-stack_ax_size], old_terminal),
axis=self.stack_dimension,
)
infos[i]["terminal_observation"] = new_terminal
previous_stack = self.stacked_obs[env_idx, ..., :shift]
new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
infos[env_idx]["terminal_observation"] = new_terminal
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stackedobs[i] = 0
self.stacked_obs[env_idx] = 0
if self.channels_first:
self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations
self.stacked_obs[:, shift:, ...] = observations
else:
self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations
return self.stackedobs, infos
class StackedDictObservations(StackedObservations):
"""
Frame stacking wrapper for dictionary data.
Dimension to stack over is either first (channels-first) or
last (channels-last), which is detected automatically using
``common.preprocessing.is_image_space_channels_first`` if
observation is an image space.
:param num_envs: number of environments
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
"""
def __init__(
self,
num_envs: int,
n_stack: int,
observation_space: spaces.Dict,
channels_order: Optional[Union[str, Dict[str, str]]] = None,
):
self.n_stack = n_stack
self.channels_first = {}
self.stack_dimension = {}
self.stackedobs = {}
self.repeat_axis = {}
for key, subspace in observation_space.spaces.items():
assert isinstance(subspace, spaces.Box), "StackedDictObservations only works with nested gym.spaces.Box"
if isinstance(channels_order, str) or channels_order is None:
subspace_channel_order = channels_order
else:
subspace_channel_order = channels_order[key]
(
self.channels_first[key],
self.stack_dimension[key],
self.stackedobs[key],
self.repeat_axis[key],
) = self.compute_stacking(num_envs, n_stack, subspace, subspace_channel_order)
def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict:
"""
Returns the stacked version of a Dict observation space
:param observation_space: Dict observation space to stack
:return: stacked observation space
"""
spaces_dict = {}
for key, subspace in observation_space.spaces.items():
low = np.repeat(subspace.low, self.n_stack, axis=self.repeat_axis[key])
high = np.repeat(subspace.high, self.n_stack, axis=self.repeat_axis[key])
spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype)
return spaces.Dict(spaces=spaces_dict)
def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch
"""
Resets the stacked observations, adds the reset observation to the stack, and returns the stack
:param observation: Reset observation
:return: Stacked reset observations
"""
for key, obs in observation.items():
self.stackedobs[key][...] = 0
if self.channels_first[key]:
self.stackedobs[key][:, -obs.shape[self.stack_dimension[key]] :, ...] = obs
else:
self.stackedobs[key][..., -obs.shape[self.stack_dimension[key]] :] = obs
return self.stackedobs
def update(
self,
observations: Dict[str, np.ndarray],
dones: np.ndarray,
infos: List[Dict[str, Any]],
) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch
"""
Adds the observations to the stack and uses the dones to update the infos.
:param observations: Dict of numpy arrays of observations
:param dones: numpy array of dones
:param infos: dict of infos
:return: tuple of the stacked observations and the updated infos
"""
for key in self.stackedobs.keys():
stack_ax_size = observations[key].shape[self.stack_dimension[key]]
self.stackedobs[key] = np.roll(
self.stackedobs[key],
shift=-stack_ax_size,
axis=self.stack_dimension[key],
)
for i, done in enumerate(dones):
if done:
if "terminal_observation" in infos[i]:
old_terminal = infos[i]["terminal_observation"][key]
if self.channels_first[key]:
new_terminal = np.vstack(
(
self.stackedobs[key][i, :-stack_ax_size, ...],
old_terminal,
)
)
else:
new_terminal = np.concatenate(
(
self.stackedobs[key][i, ..., :-stack_ax_size],
old_terminal,
),
axis=self.stack_dimension[key],
)
infos[i]["terminal_observation"][key] = new_terminal
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stackedobs[key][i] = 0
if self.channels_first[key]:
self.stackedobs[key][:, -stack_ax_size:, ...] = observations[key]
else:
self.stackedobs[key][..., -stack_ax_size:] = observations[key]
return self.stackedobs, infos
self.stacked_obs[..., shift:] = observations
return self.stacked_obs, infos

View file

@ -1,63 +1,40 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np
from gym import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment. Designed for image observations.
Uses the StackedObservations class, or StackedDictObservations depending on the observations space
:param venv: the vectorized environment to wrap
:param venv: Vectorized environment to wrap
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
"""
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
self.venv = venv
self.n_stack = n_stack
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
assert isinstance(
venv.observation_space, (spaces.Box, spaces.Dict)
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"
wrapped_obs_space = venv.observation_space
if isinstance(wrapped_obs_space, spaces.Box):
assert not isinstance(
channels_order, dict
), f"Expected None or string for channels_order but received {channels_order}"
self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
elif isinstance(wrapped_obs_space, spaces.Dict):
self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
else:
raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")
observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
observation_space = self.stacked_obs.stacked_observation_space
super().__init__(venv, observation_space=observation_space)
def step_wait(
self,
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
observations, rewards, dones, infos = self.venv.step_wait()
observations, infos = self.stackedobs.update(observations, dones, infos)
observations, infos = self.stacked_obs.update(observations, dones, infos)
return observations, rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
"""
observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
observation = self.stackedobs.reset(observation)
observation = self.stacked_obs.reset(observation)
return observation
def close(self) -> None:
self.venv.close()

View file

@ -1 +1 @@
1.8.0a3
1.8.0a4

View file

@ -0,0 +1,314 @@
import numpy as np
from gym import spaces
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
compute_stacking = StackedObservations.compute_stacking
NUM_ENVS = 2
N_STACK = 4
H, W, C = 16, 24, 3
def test_compute_stacking_box():
space = spaces.Box(-1, 1, (4,))
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
assert not channels_first # default is channel last
assert stack_dimension == -1
assert stacked_shape == (N_STACK * 4,)
assert repeat_axis == -1
def test_compute_stacking_multidim_box():
space = spaces.Box(-1, 1, (4, 5))
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
assert not channels_first # default is channel last
assert stack_dimension == -1
assert stacked_shape == (4, N_STACK * 5)
assert repeat_axis == -1
def test_compute_stacking_multidim_box_channel_first():
space = spaces.Box(-1, 1, (4, 5))
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(
N_STACK, observation_space=space, channels_order="first"
)
assert channels_first # default is channel last
assert stack_dimension == 1
assert stacked_shape == (N_STACK * 4, 5)
assert repeat_axis == 0
def test_compute_stacking_image_channel_first():
"""Detect that image is channel first and stack in that dimension."""
space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
assert channels_first # default is channel last
assert stack_dimension == 1
assert stacked_shape == (N_STACK * C, H, W)
assert repeat_axis == 0
def test_compute_stacking_image_channel_last():
"""Detect that image is channel last and stack in that dimension."""
space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space)
assert not channels_first # default is channel last
assert stack_dimension == -1
assert stacked_shape == (H, W, N_STACK * C)
assert repeat_axis == -1
def test_compute_stacking_image_channel_first_stack_last():
"""Detect that image is channel first and stack in that dimension."""
space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(
N_STACK, observation_space=space, channels_order="last"
)
assert not channels_first # default is channel last
assert stack_dimension == -1
assert stacked_shape == (C, H, N_STACK * W)
assert repeat_axis == -1
def test_compute_stacking_image_channel_last_stack_first():
"""Detect that image is channel last and stack in that dimension."""
space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(
N_STACK, observation_space=space, channels_order="first"
)
assert channels_first # default is channel last
assert stack_dimension == 1
assert stacked_shape == (N_STACK * H, W, C)
assert repeat_axis == 0
def test_reset_update_box():
space = spaces.Box(-1, 1, (4,))
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs = stacked_observations.reset(observations_1)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4)
assert stacked_obs.dtype == space.dtype
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4)
assert stacked_obs.dtype == space.dtype
assert np.array_equal(
stacked_obs,
np.concatenate(
(np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1
),
)
def test_reset_update_multidim_box():
space = spaces.Box(-1, 1, (4, 5))
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs = stacked_observations.reset(observations_1)
assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5)
assert stacked_obs.dtype == space.dtype
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5)
assert stacked_obs.dtype == space.dtype
assert np.array_equal(
stacked_obs,
np.concatenate(
(np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1
),
)
def test_reset_update_multidim_box_channel_first():
space = spaces.Box(-1, 1, (4, 5))
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first")
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs = stacked_observations.reset(observations_1)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5)
assert stacked_obs.dtype == space.dtype
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5)
assert stacked_obs.dtype == space.dtype
assert np.array_equal(
stacked_obs,
np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1),
)
def test_reset_update_image_channel_first():
space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs = stacked_observations.reset(observations_1)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W)
assert stacked_obs.dtype == space.dtype
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W)
assert stacked_obs.dtype == space.dtype
assert np.array_equal(
stacked_obs,
np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1),
)
def test_reset_update_image_channel_last():
space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs = stacked_observations.reset(observations_1)
assert stacked_obs.shape == (NUM_ENVS, H, W, N_STACK * C)
assert stacked_obs.dtype == space.dtype
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs.shape == (NUM_ENVS, H, W, N_STACK * C)
assert stacked_obs.dtype == space.dtype
assert np.array_equal(
stacked_obs,
np.concatenate(
(np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1
),
)
def test_reset_update_image_channel_first_stack_last():
space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8)
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="last")
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs = stacked_observations.reset(observations_1)
assert stacked_obs.shape == (NUM_ENVS, C, H, N_STACK * W)
assert stacked_obs.dtype == space.dtype
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs.shape == (NUM_ENVS, C, H, N_STACK * W)
assert stacked_obs.dtype == space.dtype
assert np.array_equal(
stacked_obs,
np.concatenate(
(np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1
),
)
def test_reset_update_image_channel_last_stack_first():
space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8)
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first")
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs = stacked_observations.reset(observations_1)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * H, W, C)
assert stacked_obs.dtype == space.dtype
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs.shape == (NUM_ENVS, N_STACK * H, W, C)
assert stacked_obs.dtype == space.dtype
assert np.array_equal(
stacked_obs,
np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1),
)
def test_reset_update_dict():
space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, C), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))})
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"})
observations_1 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()}
stacked_obs = stacked_observations.reset(observations_1)
assert isinstance(stacked_obs, dict)
assert stacked_obs["key1"].shape == (NUM_ENVS, N_STACK * H, W, C)
assert stacked_obs["key2"].shape == (NUM_ENVS, 4, N_STACK * 5)
assert stacked_obs["key1"].dtype == space["key1"].dtype
assert stacked_obs["key2"].dtype == space["key2"].dtype
observations_2 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()}
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_obs, infos = stacked_observations.update(observations_2, dones, infos)
assert stacked_obs["key1"].shape == (NUM_ENVS, N_STACK * H, W, C)
assert stacked_obs["key2"].shape == (NUM_ENVS, 4, N_STACK * 5)
assert stacked_obs["key1"].dtype == space["key1"].dtype
assert stacked_obs["key2"].dtype == space["key2"].dtype
assert np.array_equal(
stacked_obs["key1"],
np.concatenate(
(
np.zeros_like(observations_1["key1"]),
np.zeros_like(observations_1["key1"]),
observations_1["key1"],
observations_2["key1"],
),
axis=1,
),
)
assert np.array_equal(
stacked_obs["key2"],
np.concatenate(
(
np.zeros_like(observations_1["key2"]),
np.zeros_like(observations_1["key2"]),
observations_1["key2"],
observations_2["key2"],
),
axis=-1,
),
)
def test_episode_termination_box():
space = spaces.Box(-1, 1, (4,))
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space)
observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_observations.reset(observations_1)
observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)])
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_observations.update(observations_2, dones, infos)
terminal_observation = space.sample()
infos[1]["terminal_observation"] = terminal_observation # episode termination in env1
dones[1] = True
observations_3 = np.stack([space.sample() for _ in range(NUM_ENVS)])
stacked_obs, infos = stacked_observations.update(observations_3, dones, infos)
zeros = np.zeros_like(observations_1[0])
true_stacked_obs_env1 = np.concatenate((zeros, observations_1[0], observations_2[0], observations_3[0]), axis=-1)
true_stacked_obs_env2 = np.concatenate((zeros, zeros, zeros, observations_3[1]), axis=-1)
true_stacked_obs = np.stack((true_stacked_obs_env1, true_stacked_obs_env2))
assert np.array_equal(true_stacked_obs, stacked_obs)
def test_episode_termination_dict():
space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, 3), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))})
stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"})
observations_1 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()}
stacked_observations.reset(observations_1)
observations_2 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()}
dones = np.zeros((NUM_ENVS,), dtype=bool)
infos = [{} for _ in range(NUM_ENVS)]
stacked_observations.update(observations_2, dones, infos)
terminal_observation = space.sample()
infos[1]["terminal_observation"] = terminal_observation # episode termination in env1
dones[1] = True
observations_3 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()}
stacked_obs, infos = stacked_observations.update(observations_3, dones, infos)
for key, axis in zip(observations_1.keys(), [0, -1]):
zeros = np.zeros_like(observations_1[key][0])
true_stacked_obs_env1 = np.concatenate(
(zeros, observations_1[key][0], observations_2[key][0], observations_3[key][0]), axis
)
true_stacked_obs_env2 = np.concatenate((zeros, zeros, zeros, observations_3[key][1]), axis)
true_stacked_obs = np.stack((true_stacked_obs_env1, true_stacked_obs_env2))
assert np.array_equal(true_stacked_obs, stacked_obs[key])