From 2e4a45020ec619b09e2b1ccff14fa4f2c291dc77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 6 Feb 2023 22:41:59 +0100 Subject: [PATCH] Refactor observation stacking (#1238) * refactor stacking obs * Improve docstring * remove all StackedDictObservations * Update tests and make stacked obs clearer * Fix type check * fix stacked_observation_space * undo init change, deprecate StackedDictObservations * deprecate stack_observation_space * type hints * ignore pytype errors * undo vecenv doc change * Deprecation warning in StackedDictObs doctstring * Fix vec_env.rst * Fix __all__ sorting * fix pytype ignore statement * Update docstring * stack * Remove n_stack * Update changelog * Simplify code * Rename test file * Re-use variable for shift * Fix doc build * Remove pytype comment * Disable pytype error --------- Co-authored-by: Antonin RAFFIN --- docs/guide/vec_envs.rst | 6 - docs/misc/changelog.rst | 4 +- setup.cfg | 1 - stable_baselines3/common/vec_env/__init__.py | 3 +- .../common/vec_env/stacked_observations.py | 316 +++++++----------- .../common/vec_env/vec_frame_stack.py | 47 +-- stable_baselines3/version.txt | 2 +- tests/test_vec_stacked_obs.py | 314 +++++++++++++++++ 8 files changed, 459 insertions(+), 234 deletions(-) create mode 100644 tests/test_vec_stacked_obs.py diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index b074dad..d847811 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -122,12 +122,6 @@ StackedObservations .. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations :members: -StackedDictObservations -~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedDictObservations - :members: - VecNormalize ~~~~~~~~~~~~ diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index acfa5e3..9d24bbe 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,13 +4,14 @@ Changelog ========== -Release 1.8.0a3 (WIP) +Release 1.8.0a4 (WIP) -------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed shared layers in ``mlp_extractor`` (@AlexPasqua) +- Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed) New Features: ^^^^^^^^^^^^^ @@ -36,6 +37,7 @@ Others: - Fixed ``tests/test_tensorboard.py`` type hint - Fixed ``tests/test_vec_normalize.py`` type hint - Fixed ``stable_baselines3/common/monitor.py`` type hint +- Added tests for StackedObservations Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 37ffa17..3698ded 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,6 @@ exclude = (?x)( | stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/base_vec_env.py$ | stable_baselines3/common/vec_env/dummy_vec_env.py$ - | stable_baselines3/common/vec_env/stacked_observations.py$ | stable_baselines3/common/vec_env/subproc_vec_env.py$ | stable_baselines3/common/vec_env/util.py$ | stable_baselines3/common/vec_env/vec_extract_dict_obs.py$ diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 33a103a..2c03637 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -4,7 +4,7 @@ from typing import Optional, Type, Union from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv -from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations +from stable_baselines3.common.vec_env.stacked_observations import StackedObservations from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs @@ -78,7 +78,6 @@ __all__ = [ "VecEnv", "VecEnvWrapper", "DummyVecEnv", - "StackedDictObservations", "StackedObservations", "SubprocVecEnv", "VecCheckNan", diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index d373b87..a26812c 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -1,61 +1,80 @@ import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union import numpy as np from gym import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first +TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray]) -class StackedObservations: + +# Disable errors for pytype which doesn't play well with Generic[TypeVar] +# mypy check passes though +# pytype: disable=attribute-error +class StackedObservations(Generic[TObs]): """ Frame stacking wrapper for data. - Dimension to stack over is either first (channels-first) or - last (channels-last), which is detected automatically using - ``common.preprocessing.is_image_space_channels_first`` if - observation is an image space. + Dimension to stack over is either first (channels-first) or last (channels-last), which is detected automatically using + ``common.preprocessing.is_image_space_channels_first`` if observation is an image space. - :param num_envs: number of environments + :param num_envs: Number of environments :param n_stack: Number of frames to stack - :param observation_space: Environment observation space. + :param observation_space: Environment observation space :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. - If None, automatically detect channel to stack over in case of image observation or default to "last" (default). + If None, automatically detect channel to stack over in case of image observation or default to "last". + For Dict space, channels_order can also be a dictionary. """ def __init__( self, num_envs: int, n_stack: int, - observation_space: spaces.Space, - channels_order: Optional[str] = None, - ): + observation_space: Union[spaces.Box, spaces.Dict], # Replace by Space[TObs] in gym>=0.26 + channels_order: Optional[Union[str, Mapping[str, Optional[str]]]] = None, + ) -> None: self.n_stack = n_stack - ( - self.channels_first, - self.stack_dimension, - self.stackedobs, - self.repeat_axis, - ) = self.compute_stacking(num_envs, n_stack, observation_space, channels_order) - super().__init__() + self.observation_space = observation_space + if isinstance(observation_space, spaces.Dict): + if not isinstance(channels_order, Mapping): + channels_order = {key: channels_order for key in observation_space.spaces.keys()} + self.sub_stacked_observations = { + key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) + for key, subspace in observation_space.spaces.items() + } + self.stacked_observation_space = spaces.Dict( + {key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()} + ) # type: spaces.Dict # make mypy happy + elif isinstance(observation_space, spaces.Box): + if isinstance(channels_order, Mapping): + raise TypeError("When the observation space is Box, channels_order can't be a dict.") + + self.channels_first, self.stack_dimension, self.stacked_shape, self.repeat_axis = self.compute_stacking( + n_stack, observation_space, channels_order + ) + low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis) + high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis) + self.stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype) + self.stacked_obs = np.zeros((num_envs,) + self.stacked_shape, dtype=observation_space.dtype) + else: + raise TypeError( + f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided." + ) @staticmethod def compute_stacking( - num_envs: int, - n_stack: int, - observation_space: spaces.Box, - channels_order: Optional[str] = None, - ) -> Tuple[bool, int, np.ndarray, int]: + n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None + ) -> Tuple[bool, int, Tuple[int, ...], int]: """ Calculates the parameters in order to stack observations - :param num_envs: Number of environments in the stack - :param n_stack: The number of observations to stack - :param observation_space: The observation space - :param channels_order: The order of the channels - :return: tuple of channels_first, stack_dimension, stackedobs, repeat_axis + :param n_stack: Number of observations to stack + :param observation_space: Observation space + :param channels_order: Order of the channels + :return: Tuple of channels_first, stack_dimension, stackedobs, repeat_axis """ - channels_first = False + if channels_order is None: # Detect channel location automatically for images if is_image_space(observation_space): @@ -74,192 +93,113 @@ class StackedObservations: # This includes the vec-env dimension (first) stack_dimension = 1 if channels_first else -1 repeat_axis = 0 if channels_first else -1 - low = np.repeat(observation_space.low, n_stack, axis=repeat_axis) - stackedobs = np.zeros((num_envs,) + low.shape, low.dtype) - return channels_first, stack_dimension, stackedobs, repeat_axis + stacked_shape = list(observation_space.shape) + stacked_shape[repeat_axis] *= n_stack + return channels_first, stack_dimension, tuple(stacked_shape), repeat_axis - def stack_observation_space(self, observation_space: spaces.Box) -> spaces.Box: + def stack_observation_space(self, observation_space: Union[spaces.Box, spaces.Dict]) -> Union[spaces.Box, spaces.Dict]: """ - Given an observation space, returns a new observation space with stacked observations + This function is deprecated. + + As an alternative, use + + .. code-block:: python + + low = np.repeat(observation_space.low, stacked_observation.n_stack, axis=stacked_observation.repeat_axis) + high = np.repeat(observation_space.high, stacked_observation.n_stack, axis=stacked_observation.repeat_axis) + stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype) :return: New observation space with stacked dimensions """ + warnings.warn( + "stack_observation_space is deprecated and will be removed in the next SB3 release. " + "Please refer to the docstring for a workaround.", + DeprecationWarning, + ) + if isinstance(observation_space, spaces.Dict): + return spaces.Dict( + { + key: sub_stacked_observation.stack_observation_space(sub_stacked_observation.observation_space) + for key, sub_stacked_observation in self.sub_stacked_observations.items() + } + ) low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) return spaces.Box(low=low, high=high, dtype=observation_space.dtype) - def reset(self, observation: np.ndarray) -> np.ndarray: + def reset(self, observation: TObs) -> TObs: """ - Resets the stackedobs, adds the reset observation to the stack, and returns the stack + Reset the stacked_obs, add the reset observation to the stack, and return the stack. :param observation: Reset observation :return: The stacked reset observation """ - self.stackedobs[...] = 0 + if isinstance(observation, dict): + return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()} + + self.stacked_obs[...] = 0 if self.channels_first: - self.stackedobs[:, -observation.shape[self.stack_dimension] :, ...] = observation + self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation else: - self.stackedobs[..., -observation.shape[self.stack_dimension] :] = observation - return self.stackedobs + self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation + return self.stacked_obs def update( self, - observations: np.ndarray, + observations: TObs, dones: np.ndarray, infos: List[Dict[str, Any]], - ) -> Tuple[np.ndarray, List[Dict[str, Any]]]: + ) -> Tuple[TObs, List[Dict[str, Any]]]: """ - Adds the observations to the stack and uses the dones to update the infos. + Add the observations to the stack and use the dones to update the infos. - :param observations: numpy array of observations - :param dones: numpy array of done info - :param infos: numpy array of info dicts - :return: tuple of the stacked observations and the updated infos + :param observations: Observations + :param dones: Dones + :param infos: Infos + :return: Tuple of the stacked observations and the updated infos """ - stack_ax_size = observations.shape[self.stack_dimension] - self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension) - for i, done in enumerate(dones): + if isinstance(observations, dict): + # From [{}, {terminal_obs: {key1: ..., key2: ...}}] + # to {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]} + sub_infos = { + key: [ + {"terminal_observation": info["terminal_observation"][key]} if "terminal_observation" in info else {} + for info in infos + ] + for key in observations.keys() + } + + stacked_obs = {} + stacked_infos = {} + for key, obs in observations.items(): + stacked_obs[key], stacked_infos[key] = self.sub_stacked_observations[key].update(obs, dones, sub_infos[key]) + + # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]} + # to [{}, {terminal_obs: {key1: ..., key2: ...}}] + for key in stacked_infos.keys(): + for env_idx in range(len(infos)): + if "terminal_observation" in infos[env_idx]: + infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"] + return stacked_obs, infos + + shift = -observations.shape[self.stack_dimension] + self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension) + for env_idx, done in enumerate(dones): if done: - if "terminal_observation" in infos[i]: - old_terminal = infos[i]["terminal_observation"] + if "terminal_observation" in infos[env_idx]: + old_terminal = infos[env_idx]["terminal_observation"] if self.channels_first: - new_terminal = np.concatenate( - (self.stackedobs[i, :-stack_ax_size, ...], old_terminal), - axis=0, # self.stack_dimension - 1, as there is not batch dim - ) + previous_stack = self.stacked_obs[env_idx, :shift, ...] else: - new_terminal = np.concatenate( - (self.stackedobs[i, ..., :-stack_ax_size], old_terminal), - axis=self.stack_dimension, - ) - infos[i]["terminal_observation"] = new_terminal + previous_stack = self.stacked_obs[env_idx, ..., :shift] + + new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis) + infos[env_idx]["terminal_observation"] = new_terminal else: warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") - self.stackedobs[i] = 0 + self.stacked_obs[env_idx] = 0 if self.channels_first: - self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations + self.stacked_obs[:, shift:, ...] = observations else: - self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations - return self.stackedobs, infos - - -class StackedDictObservations(StackedObservations): - """ - Frame stacking wrapper for dictionary data. - - Dimension to stack over is either first (channels-first) or - last (channels-last), which is detected automatically using - ``common.preprocessing.is_image_space_channels_first`` if - observation is an image space. - - :param num_envs: number of environments - :param n_stack: Number of frames to stack - :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. - If None, automatically detect channel to stack over in case of image observation or default to "last" (default). - """ - - def __init__( - self, - num_envs: int, - n_stack: int, - observation_space: spaces.Dict, - channels_order: Optional[Union[str, Dict[str, str]]] = None, - ): - self.n_stack = n_stack - self.channels_first = {} - self.stack_dimension = {} - self.stackedobs = {} - self.repeat_axis = {} - - for key, subspace in observation_space.spaces.items(): - assert isinstance(subspace, spaces.Box), "StackedDictObservations only works with nested gym.spaces.Box" - if isinstance(channels_order, str) or channels_order is None: - subspace_channel_order = channels_order - else: - subspace_channel_order = channels_order[key] - ( - self.channels_first[key], - self.stack_dimension[key], - self.stackedobs[key], - self.repeat_axis[key], - ) = self.compute_stacking(num_envs, n_stack, subspace, subspace_channel_order) - - def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict: - """ - Returns the stacked version of a Dict observation space - - :param observation_space: Dict observation space to stack - :return: stacked observation space - """ - spaces_dict = {} - for key, subspace in observation_space.spaces.items(): - low = np.repeat(subspace.low, self.n_stack, axis=self.repeat_axis[key]) - high = np.repeat(subspace.high, self.n_stack, axis=self.repeat_axis[key]) - spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype) - return spaces.Dict(spaces=spaces_dict) - - def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch - """ - Resets the stacked observations, adds the reset observation to the stack, and returns the stack - - :param observation: Reset observation - :return: Stacked reset observations - """ - for key, obs in observation.items(): - self.stackedobs[key][...] = 0 - if self.channels_first[key]: - self.stackedobs[key][:, -obs.shape[self.stack_dimension[key]] :, ...] = obs - else: - self.stackedobs[key][..., -obs.shape[self.stack_dimension[key]] :] = obs - return self.stackedobs - - def update( - self, - observations: Dict[str, np.ndarray], - dones: np.ndarray, - infos: List[Dict[str, Any]], - ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch - """ - Adds the observations to the stack and uses the dones to update the infos. - - :param observations: Dict of numpy arrays of observations - :param dones: numpy array of dones - :param infos: dict of infos - :return: tuple of the stacked observations and the updated infos - """ - for key in self.stackedobs.keys(): - stack_ax_size = observations[key].shape[self.stack_dimension[key]] - self.stackedobs[key] = np.roll( - self.stackedobs[key], - shift=-stack_ax_size, - axis=self.stack_dimension[key], - ) - - for i, done in enumerate(dones): - if done: - if "terminal_observation" in infos[i]: - old_terminal = infos[i]["terminal_observation"][key] - if self.channels_first[key]: - new_terminal = np.vstack( - ( - self.stackedobs[key][i, :-stack_ax_size, ...], - old_terminal, - ) - ) - else: - new_terminal = np.concatenate( - ( - self.stackedobs[key][i, ..., :-stack_ax_size], - old_terminal, - ), - axis=self.stack_dimension[key], - ) - infos[i]["terminal_observation"][key] = new_terminal - else: - warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") - self.stackedobs[key][i] = 0 - if self.channels_first[key]: - self.stackedobs[key][:, -stack_ax_size:, ...] = observations[key] - else: - self.stackedobs[key][..., -stack_ax_size:] = observations[key] - return self.stackedobs, infos + self.stacked_obs[..., shift:] = observations + return self.stacked_obs, infos diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index d933104..8a020dd 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -1,63 +1,40 @@ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import numpy as np from gym import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper -from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations +from stable_baselines3.common.vec_env.stacked_observations import StackedObservations class VecFrameStack(VecEnvWrapper): """ Frame stacking wrapper for vectorized environment. Designed for image observations. - Uses the StackedObservations class, or StackedDictObservations depending on the observations space - - :param venv: the vectorized environment to wrap + :param venv: Vectorized environment to wrap :param n_stack: Number of frames to stack :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. If None, automatically detect channel to stack over in case of image observation or default to "last" (default). Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces """ - def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): - self.venv = venv - self.n_stack = n_stack + def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None: + assert isinstance( + venv.observation_space, (spaces.Box, spaces.Dict) + ), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces" - wrapped_obs_space = venv.observation_space - - if isinstance(wrapped_obs_space, spaces.Box): - assert not isinstance( - channels_order, dict - ), f"Expected None or string for channels_order but received {channels_order}" - self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) - - elif isinstance(wrapped_obs_space, spaces.Dict): - self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) - - else: - raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") - - observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) - VecEnvWrapper.__init__(self, venv, observation_space=observation_space) + self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order) + observation_space = self.stacked_obs.stacked_observation_space + super().__init__(venv, observation_space=observation_space) def step_wait( self, ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: observations, rewards, dones, infos = self.venv.step_wait() - - observations, infos = self.stackedobs.update(observations, dones, infos) - + observations, infos = self.stacked_obs.update(observations, dones, infos) return observations, rewards, dones, infos def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: - """ - Reset all environments - """ observation = self.venv.reset() # pytype:disable=annotation-type-mismatch - - observation = self.stackedobs.reset(observation) + observation = self.stacked_obs.reset(observation) return observation - - def close(self) -> None: - self.venv.close() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index f5e9264..e8175d3 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a3 +1.8.0a4 diff --git a/tests/test_vec_stacked_obs.py b/tests/test_vec_stacked_obs.py new file mode 100644 index 0000000..0a7aa39 --- /dev/null +++ b/tests/test_vec_stacked_obs.py @@ -0,0 +1,314 @@ +import numpy as np +from gym import spaces + +from stable_baselines3.common.vec_env.stacked_observations import StackedObservations + +compute_stacking = StackedObservations.compute_stacking +NUM_ENVS = 2 +N_STACK = 4 +H, W, C = 16, 24, 3 + + +def test_compute_stacking_box(): + space = spaces.Box(-1, 1, (4,)) + channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) + assert not channels_first # default is channel last + assert stack_dimension == -1 + assert stacked_shape == (N_STACK * 4,) + assert repeat_axis == -1 + + +def test_compute_stacking_multidim_box(): + space = spaces.Box(-1, 1, (4, 5)) + channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) + assert not channels_first # default is channel last + assert stack_dimension == -1 + assert stacked_shape == (4, N_STACK * 5) + assert repeat_axis == -1 + + +def test_compute_stacking_multidim_box_channel_first(): + space = spaces.Box(-1, 1, (4, 5)) + channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking( + N_STACK, observation_space=space, channels_order="first" + ) + assert channels_first # default is channel last + assert stack_dimension == 1 + assert stacked_shape == (N_STACK * 4, 5) + assert repeat_axis == 0 + + +def test_compute_stacking_image_channel_first(): + """Detect that image is channel first and stack in that dimension.""" + space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) + channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) + assert channels_first # default is channel last + assert stack_dimension == 1 + assert stacked_shape == (N_STACK * C, H, W) + assert repeat_axis == 0 + + +def test_compute_stacking_image_channel_last(): + """Detect that image is channel last and stack in that dimension.""" + space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) + channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking(N_STACK, observation_space=space) + assert not channels_first # default is channel last + assert stack_dimension == -1 + assert stacked_shape == (H, W, N_STACK * C) + assert repeat_axis == -1 + + +def test_compute_stacking_image_channel_first_stack_last(): + """Detect that image is channel first and stack in that dimension.""" + space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) + channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking( + N_STACK, observation_space=space, channels_order="last" + ) + assert not channels_first # default is channel last + assert stack_dimension == -1 + assert stacked_shape == (C, H, N_STACK * W) + assert repeat_axis == -1 + + +def test_compute_stacking_image_channel_last_stack_first(): + """Detect that image is channel last and stack in that dimension.""" + space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) + channels_first, stack_dimension, stacked_shape, repeat_axis = compute_stacking( + N_STACK, observation_space=space, channels_order="first" + ) + assert channels_first # default is channel last + assert stack_dimension == 1 + assert stacked_shape == (N_STACK * H, W, C) + assert repeat_axis == 0 + + +def test_reset_update_box(): + space = spaces.Box(-1, 1, (4,)) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs = stacked_observations.reset(observations_1) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4) + assert stacked_obs.dtype == space.dtype + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4) + assert stacked_obs.dtype == space.dtype + assert np.array_equal( + stacked_obs, + np.concatenate( + (np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 + ), + ) + + +def test_reset_update_multidim_box(): + space = spaces.Box(-1, 1, (4, 5)) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs = stacked_observations.reset(observations_1) + assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5) + assert stacked_obs.dtype == space.dtype + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs.shape == (NUM_ENVS, 4, N_STACK * 5) + assert stacked_obs.dtype == space.dtype + assert np.array_equal( + stacked_obs, + np.concatenate( + (np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 + ), + ) + + +def test_reset_update_multidim_box_channel_first(): + space = spaces.Box(-1, 1, (4, 5)) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first") + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs = stacked_observations.reset(observations_1) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5) + assert stacked_obs.dtype == space.dtype + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * 4, 5) + assert stacked_obs.dtype == space.dtype + assert np.array_equal( + stacked_obs, + np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1), + ) + + +def test_reset_update_image_channel_first(): + space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs = stacked_observations.reset(observations_1) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W) + assert stacked_obs.dtype == space.dtype + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * C, H, W) + assert stacked_obs.dtype == space.dtype + assert np.array_equal( + stacked_obs, + np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1), + ) + + +def test_reset_update_image_channel_last(): + space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs = stacked_observations.reset(observations_1) + assert stacked_obs.shape == (NUM_ENVS, H, W, N_STACK * C) + assert stacked_obs.dtype == space.dtype + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs.shape == (NUM_ENVS, H, W, N_STACK * C) + assert stacked_obs.dtype == space.dtype + assert np.array_equal( + stacked_obs, + np.concatenate( + (np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 + ), + ) + + +def test_reset_update_image_channel_first_stack_last(): + space = spaces.Box(0, 255, (C, H, W), dtype=np.uint8) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="last") + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs = stacked_observations.reset(observations_1) + assert stacked_obs.shape == (NUM_ENVS, C, H, N_STACK * W) + assert stacked_obs.dtype == space.dtype + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs.shape == (NUM_ENVS, C, H, N_STACK * W) + assert stacked_obs.dtype == space.dtype + assert np.array_equal( + stacked_obs, + np.concatenate( + (np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=-1 + ), + ) + + +def test_reset_update_image_channel_last_stack_first(): + space = spaces.Box(0, 255, (H, W, C), dtype=np.uint8) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order="first") + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs = stacked_observations.reset(observations_1) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * H, W, C) + assert stacked_obs.dtype == space.dtype + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs.shape == (NUM_ENVS, N_STACK * H, W, C) + assert stacked_obs.dtype == space.dtype + assert np.array_equal( + stacked_obs, + np.concatenate((np.zeros_like(observations_1), np.zeros_like(observations_1), observations_1, observations_2), axis=1), + ) + + +def test_reset_update_dict(): + space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, C), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))}) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"}) + observations_1 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} + stacked_obs = stacked_observations.reset(observations_1) + assert isinstance(stacked_obs, dict) + assert stacked_obs["key1"].shape == (NUM_ENVS, N_STACK * H, W, C) + assert stacked_obs["key2"].shape == (NUM_ENVS, 4, N_STACK * 5) + assert stacked_obs["key1"].dtype == space["key1"].dtype + assert stacked_obs["key2"].dtype == space["key2"].dtype + observations_2 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_obs, infos = stacked_observations.update(observations_2, dones, infos) + assert stacked_obs["key1"].shape == (NUM_ENVS, N_STACK * H, W, C) + assert stacked_obs["key2"].shape == (NUM_ENVS, 4, N_STACK * 5) + assert stacked_obs["key1"].dtype == space["key1"].dtype + assert stacked_obs["key2"].dtype == space["key2"].dtype + + assert np.array_equal( + stacked_obs["key1"], + np.concatenate( + ( + np.zeros_like(observations_1["key1"]), + np.zeros_like(observations_1["key1"]), + observations_1["key1"], + observations_2["key1"], + ), + axis=1, + ), + ) + assert np.array_equal( + stacked_obs["key2"], + np.concatenate( + ( + np.zeros_like(observations_1["key2"]), + np.zeros_like(observations_1["key2"]), + observations_1["key2"], + observations_2["key2"], + ), + axis=-1, + ), + ) + + +def test_episode_termination_box(): + space = spaces.Box(-1, 1, (4,)) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space) + observations_1 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_observations.reset(observations_1) + observations_2 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_observations.update(observations_2, dones, infos) + terminal_observation = space.sample() + infos[1]["terminal_observation"] = terminal_observation # episode termination in env1 + dones[1] = True + observations_3 = np.stack([space.sample() for _ in range(NUM_ENVS)]) + stacked_obs, infos = stacked_observations.update(observations_3, dones, infos) + zeros = np.zeros_like(observations_1[0]) + true_stacked_obs_env1 = np.concatenate((zeros, observations_1[0], observations_2[0], observations_3[0]), axis=-1) + true_stacked_obs_env2 = np.concatenate((zeros, zeros, zeros, observations_3[1]), axis=-1) + true_stacked_obs = np.stack((true_stacked_obs_env1, true_stacked_obs_env2)) + assert np.array_equal(true_stacked_obs, stacked_obs) + + +def test_episode_termination_dict(): + space = spaces.Dict({"key1": spaces.Box(0, 255, (H, W, 3), dtype=np.uint8), "key2": spaces.Box(-1, 1, (4, 5))}) + stacked_observations = StackedObservations(NUM_ENVS, N_STACK, space, channels_order={"key1": "first", "key2": "last"}) + observations_1 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} + stacked_observations.reset(observations_1) + observations_2 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} + dones = np.zeros((NUM_ENVS,), dtype=bool) + infos = [{} for _ in range(NUM_ENVS)] + stacked_observations.update(observations_2, dones, infos) + terminal_observation = space.sample() + infos[1]["terminal_observation"] = terminal_observation # episode termination in env1 + dones[1] = True + observations_3 = {key: np.stack([subspace.sample() for _ in range(NUM_ENVS)]) for key, subspace in space.spaces.items()} + stacked_obs, infos = stacked_observations.update(observations_3, dones, infos) + + for key, axis in zip(observations_1.keys(), [0, -1]): + zeros = np.zeros_like(observations_1[key][0]) + true_stacked_obs_env1 = np.concatenate( + (zeros, observations_1[key][0], observations_2[key][0], observations_3[key][0]), axis + ) + true_stacked_obs_env2 = np.concatenate((zeros, zeros, zeros, observations_3[key][1]), axis) + true_stacked_obs = np.stack((true_stacked_obs_env1, true_stacked_obs_env2)) + assert np.array_equal(true_stacked_obs, stacked_obs[key])