From f38ddcb2787b86bcc288df31cb048ba62802970e Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 Apr 2020 16:11:23 +0200 Subject: [PATCH] Allow any number of channels --- tests/test_save_load.py | 2 +- torchy_baselines/common/policies.py | 5 ++--- torchy_baselines/common/preprocessing.py | 26 ++++++++++++------------ 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index b1005cc..5e637bc 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -180,7 +180,7 @@ def test_save_load_policy(model_class, policy_str): # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict(buffer_size=250) - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=3, + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=False) env = DummyVecEnv([lambda: env]) diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 701cc9d..fc3f7a6 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -62,8 +62,7 @@ class NatureCNN(BaseFeaturesExtractor): # TODO: custom init? # We assume CxWxH images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper - is_image_input = is_image_space(observation_space) or is_image_space(observation_space, channels_last=False) - assert is_image_input, ('You should use NatureCNN ' + assert is_image_space(observation_space), ('You should use NatureCNN ' f'only with images not with {observation_space} ' '(you are probably using `CnnPolicy` instead of `MlpPolicy`)') n_input_channels = observation_space.shape[0] @@ -203,7 +202,7 @@ class BasePolicy(nn.Module): # Handle the different cases for images # as PyTorch use channel first format - if is_image_space(self.observation_space, channels_last=False): + if is_image_space(self.observation_space): if (observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape): pass diff --git a/torchy_baselines/common/preprocessing.py b/torchy_baselines/common/preprocessing.py index 2caebcc..af81684 100644 --- a/torchy_baselines/common/preprocessing.py +++ b/torchy_baselines/common/preprocessing.py @@ -6,7 +6,9 @@ import torch.nn.functional as F from gym import spaces -def is_image_space(observation_space: spaces.Space, channels_last: bool = True) -> bool: +def is_image_space(observation_space: spaces.Space, + channels_last: bool = True, + check_channels: bool = False) -> bool: """ Check if a observation space has the shape, limits and dtype of a valid image. @@ -17,6 +19,8 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True) :param observation_space: (spaces.Space) :param channels_last: (bool) + :param check_channels: (bool) Whether to do or not the check for the number of channels. + Because of frame-skip, the observation space may have more channels than expected. :return: (bool) """ if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: @@ -28,11 +32,15 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True) if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): return False + # Skip channels check + if not check_channels: + return True # Check the number of channels if channels_last: n_channels = observation_space.shape[-1] else: n_channels = observation_space.shape[0] + # RGB, RGBD, GrayScale return n_channels in [1, 3, 4] return False @@ -79,24 +87,16 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: raise NotImplementedError() -def get_flattened_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]: +def get_flattened_obs_dim(observation_space: spaces.Space) -> int: """ Get the dimension of the observation space when flattened. It does not apply to image observation space. :param observation_space: (spaces.Space) - :return: (Union[int, Tuple[int, ...]]) + :return: (int) """ - if isinstance(observation_space, spaces.Box): - # if is_image_space(observation_space): - # raise NotImplementedError() - return np.prod(observation_space.shape) - elif isinstance(observation_space, spaces.Discrete): - # Observation is a one hot vector - return observation_space.n - else: - # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict - raise NotImplementedError() + # Use Gym internal method + return spaces.utils.flatdim(observation_space) def get_action_dim(action_space: spaces.Space) -> int: