mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Allow any number of channels
This commit is contained in:
parent
f3cb0688c4
commit
f38ddcb278
3 changed files with 16 additions and 17 deletions
|
|
@ -180,7 +180,7 @@ def test_save_load_policy(model_class, policy_str):
|
|||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=250)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=3,
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2,
|
||||
discrete=False)
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
|
|
|||
|
|
@ -62,8 +62,7 @@ class NatureCNN(BaseFeaturesExtractor):
|
|||
# TODO: custom init?
|
||||
# We assume CxWxH images (channels first)
|
||||
# Re-ordering will be done by pre-preprocessing or wrapper
|
||||
is_image_input = is_image_space(observation_space) or is_image_space(observation_space, channels_last=False)
|
||||
assert is_image_input, ('You should use NatureCNN '
|
||||
assert is_image_space(observation_space), ('You should use NatureCNN '
|
||||
f'only with images not with {observation_space} '
|
||||
'(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
|
||||
n_input_channels = observation_space.shape[0]
|
||||
|
|
@ -203,7 +202,7 @@ class BasePolicy(nn.Module):
|
|||
|
||||
# Handle the different cases for images
|
||||
# as PyTorch use channel first format
|
||||
if is_image_space(self.observation_space, channels_last=False):
|
||||
if is_image_space(self.observation_space):
|
||||
if (observation.shape == self.observation_space.shape or
|
||||
observation.shape[1:] == self.observation_space.shape):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ import torch.nn.functional as F
|
|||
from gym import spaces
|
||||
|
||||
|
||||
def is_image_space(observation_space: spaces.Space, channels_last: bool = True) -> bool:
|
||||
def is_image_space(observation_space: spaces.Space,
|
||||
channels_last: bool = True,
|
||||
check_channels: bool = False) -> bool:
|
||||
"""
|
||||
Check if a observation space has the shape, limits and dtype
|
||||
of a valid image.
|
||||
|
|
@ -17,6 +19,8 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True)
|
|||
|
||||
:param observation_space: (spaces.Space)
|
||||
:param channels_last: (bool)
|
||||
:param check_channels: (bool) Whether to do or not the check for the number of channels.
|
||||
Because of frame-skip, the observation space may have more channels than expected.
|
||||
:return: (bool)
|
||||
"""
|
||||
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
|
||||
|
|
@ -28,11 +32,15 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True)
|
|||
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
|
||||
return False
|
||||
|
||||
# Skip channels check
|
||||
if not check_channels:
|
||||
return True
|
||||
# Check the number of channels
|
||||
if channels_last:
|
||||
n_channels = observation_space.shape[-1]
|
||||
else:
|
||||
n_channels = observation_space.shape[0]
|
||||
# RGB, RGBD, GrayScale
|
||||
return n_channels in [1, 3, 4]
|
||||
return False
|
||||
|
||||
|
|
@ -79,24 +87,16 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_flattened_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
|
||||
def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
|
||||
"""
|
||||
Get the dimension of the observation space when flattened.
|
||||
It does not apply to image observation space.
|
||||
|
||||
:param observation_space: (spaces.Space)
|
||||
:return: (Union[int, Tuple[int, ...]])
|
||||
:return: (int)
|
||||
"""
|
||||
if isinstance(observation_space, spaces.Box):
|
||||
# if is_image_space(observation_space):
|
||||
# raise NotImplementedError()
|
||||
return np.prod(observation_space.shape)
|
||||
elif isinstance(observation_space, spaces.Discrete):
|
||||
# Observation is a one hot vector
|
||||
return observation_space.n
|
||||
else:
|
||||
# TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
|
||||
raise NotImplementedError()
|
||||
# Use Gym internal method
|
||||
return spaces.utils.flatdim(observation_space)
|
||||
|
||||
|
||||
def get_action_dim(action_space: spaces.Space) -> int:
|
||||
|
|
|
|||
Loading…
Reference in a new issue