Allow any number of channels

This commit is contained in:
Antonin RAFFIN 2020-04-22 16:11:23 +02:00
parent f3cb0688c4
commit f38ddcb278
3 changed files with 16 additions and 17 deletions

View file

@ -180,7 +180,7 @@ def test_save_load_policy(model_class, policy_str):
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=3,
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2,
discrete=False)
env = DummyVecEnv([lambda: env])

View file

@ -62,8 +62,7 @@ class NatureCNN(BaseFeaturesExtractor):
# TODO: custom init?
# We assume CxWxH images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
is_image_input = is_image_space(observation_space) or is_image_space(observation_space, channels_last=False)
assert is_image_input, ('You should use NatureCNN '
assert is_image_space(observation_space), ('You should use NatureCNN '
f'only with images not with {observation_space} '
'(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
n_input_channels = observation_space.shape[0]
@ -203,7 +202,7 @@ class BasePolicy(nn.Module):
# Handle the different cases for images
# as PyTorch use channel first format
if is_image_space(self.observation_space, channels_last=False):
if is_image_space(self.observation_space):
if (observation.shape == self.observation_space.shape or
observation.shape[1:] == self.observation_space.shape):
pass

View file

@ -6,7 +6,9 @@ import torch.nn.functional as F
from gym import spaces
def is_image_space(observation_space: spaces.Space, channels_last: bool = True) -> bool:
def is_image_space(observation_space: spaces.Space,
channels_last: bool = True,
check_channels: bool = False) -> bool:
"""
Check if a observation space has the shape, limits and dtype
of a valid image.
@ -17,6 +19,8 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True)
:param observation_space: (spaces.Space)
:param channels_last: (bool)
:param check_channels: (bool) Whether to do or not the check for the number of channels.
Because of frame-skip, the observation space may have more channels than expected.
:return: (bool)
"""
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
@ -28,11 +32,15 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True)
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
return False
# Skip channels check
if not check_channels:
return True
# Check the number of channels
if channels_last:
n_channels = observation_space.shape[-1]
else:
n_channels = observation_space.shape[0]
# RGB, RGBD, GrayScale
return n_channels in [1, 3, 4]
return False
@ -79,24 +87,16 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
raise NotImplementedError()
def get_flattened_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
"""
Get the dimension of the observation space when flattened.
It does not apply to image observation space.
:param observation_space: (spaces.Space)
:return: (Union[int, Tuple[int, ...]])
:return: (int)
"""
if isinstance(observation_space, spaces.Box):
# if is_image_space(observation_space):
# raise NotImplementedError()
return np.prod(observation_space.shape)
elif isinstance(observation_space, spaces.Discrete):
# Observation is a one hot vector
return observation_space.n
else:
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
raise NotImplementedError()
# Use Gym internal method
return spaces.utils.flatdim(observation_space)
def get_action_dim(action_space: spaces.Space) -> int: