From 8452106734ba1749cc4ddd5ae9fe7fd28ca55bf7 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 20 Dec 2022 13:18:28 +0100 Subject: [PATCH] Fix support of image like normalized inputs (#1214) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix support of image like normalized inputs * Improve docstring and warning message. * Don't check if obs is image when normalize_images is False (lil opt) * Comment fix * Fix normalize_images not passed to parent * Check for subclasses too * Remove useless multiline * Update version and add comment * Fix some typos Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/common/logger.rst | 2 +- docs/guide/custom_env.rst | 10 ++-- docs/guide/custom_policy.rst | 6 +-- docs/guide/examples.rst | 2 +- docs/guide/export.rst | 2 +- docs/misc/changelog.rst | 9 +++- docs/misc/projects.rst | 2 +- stable_baselines3/common/env_checker.py | 10 ++-- stable_baselines3/common/logger.py | 2 +- stable_baselines3/common/policies.py | 5 +- stable_baselines3/common/preprocessing.py | 15 ++++-- stable_baselines3/common/torch_layers.py | 39 +++++++++++---- .../common/vec_env/stacked_observations.py | 2 +- .../common/vec_env/vec_normalize.py | 27 +++++++++++ stable_baselines3/dqn/policies.py | 3 +- stable_baselines3/sac/policies.py | 1 + stable_baselines3/td3/policies.py | 3 +- stable_baselines3/version.txt | 2 +- tests/test_cnn.py | 47 ++++++++++++++++++- tests/test_dict_env.py | 11 ++++- 20 files changed, 163 insertions(+), 37 deletions(-) diff --git a/docs/common/logger.rst b/docs/common/logger.rst index 70c0379..0aa64f8 100644 --- a/docs/common/logger.rst +++ b/docs/common/logger.rst @@ -105,7 +105,7 @@ train/ - ``loss``: Current total loss value - ``n_updates``: Number of gradient updates applied so far - ``policy_gradient_loss``: Current value of the policy gradient loss (its value does not have much meaning) -- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carle estimate (or TD(lambda) estimate) +- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carlo estimate (or TD(lambda) estimate) - ``std``: Current standard deviation of the noise when using generalized State-Dependent Exploration (gSDE) diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 9fbb527..a561b2d 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -8,9 +8,13 @@ That is to say, your environment must implement the following methods (and inher .. note:: - If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255] - is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either - channel-first or channel-last. + If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255]. + By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies. + Images can be either channel-first or channel-last. + + If you want to use ``CnnPolicy`` or ``MultiInputPolicy`` with image-like observation (3D tensor) that are already normalized, you must pass ``normalize_images=False`` + to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``) + and make sure your image is in the **channel-first** format. .. note:: diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index fb25c18..5db97b1 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -139,7 +139,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t """ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256): - super(CustomCNN, self).__init__(observation_space, features_dim) + super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper n_input_channels = observation_space.shape[0] @@ -201,7 +201,7 @@ downsampling and "vector" with a single linear layer. # We do not know features-dim here before going over all the items, # so put something dummy for now. PyTorch requires calling # nn.Module.__init__ before adding modules - super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1) + super().__init__(observation_space, features_dim=1) extractors = {} @@ -374,7 +374,7 @@ If your task requires even more granular control over the policy/value architect **kwargs, ): - super(CustomActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 47bdc40..228dad0 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -194,7 +194,7 @@ You can use environments with dictionary observation spaces. This is useful in t concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles). Stable Baselines3 provides ``SimpleMultiObsEnv`` as an example of this kind of of setting. The environment is a simple grid world but the observations for each cell come in the form of dictionaries. -These dictionaries are randomly initilaized on the creation of the environment and contain a vector observation and an image observation. +These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation. .. code-block:: python diff --git a/docs/guide/export.rst b/docs/guide/export.rst index 3a21749..b50c484 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -35,7 +35,7 @@ As of June 2021, ONNX format `doesn't support None: """ Check that the input will be compatible with Stable-Baselines when the observation is apparently an image. + + :param observation_space: Observation space + :key: When the observation space comes from a Dict space, we pass the + corresponding key to have more precise warning messages. Defaults to "". """ if observation_space.dtype != np.uint8: warnings.warn( - f"It seems that your observation {key} is an image but the `dtype` " - "of your observation_space is not `np.uint8`. " + f"It seems that your observation {key} is an image but its `dtype` " + f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. " "If your observation is not an image, we recommend you to flatten the observation " "to have only a 1D vector" ) @@ -180,7 +184,7 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: # If image, check the low and high values, the type and the number of channels # and the shape (minimal value) if len(observation_space.shape) == 3: - _check_image_input(observation_space) + _check_image_input(observation_space, key) if len(observation_space.shape) not in [1, 3]: warnings.warn( diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 51b6f6b..e065992 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -73,7 +73,7 @@ class Image: class HParam: """ - Hyperparameter data class storing hyperparameters and metrics in dictionnaries + Hyperparameter data class storing hyperparameters and metrics in dictionaries :param hparam_dict: key-value pairs of hyperparameters to log :param metric_dict: key-value pairs of metrics to log diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 07287f0..3e8c920 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -87,6 +87,9 @@ class BaseModel(nn.Module): self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs + # Automatically deactivate dtype and bounds checks + if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)): + self.features_extractor_kwargs.update(dict(normalized_image=True)) def _update_features_extractor( self, @@ -430,6 +433,7 @@ class ActorCriticPolicy(BasePolicy): optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, squash_output=squash_output, + normalize_images=normalize_images, ) # Default network architecture, from stable-baselines @@ -446,7 +450,6 @@ class ActorCriticPolicy(BasePolicy): self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim - self.normalize_images = normalize_images self.log_std_init = log_std_init dist_kwargs = None # Keyword arguments for gSDE distribution diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 76dab70..e280ed7 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -27,6 +27,7 @@ def is_image_space_channels_first(observation_space: spaces.Box) -> bool: def is_image_space( observation_space: spaces.Space, check_channels: bool = False, + normalized_image: bool = False, ) -> bool: """ Check if a observation space has the shape, limits and dtype @@ -38,15 +39,21 @@ def is_image_space( :param observation_space: :param check_channels: Whether to do or not the check for the number of channels. e.g., with frame-stacking, the observation space may have more channels than expected. + :param normalized_image: Whether to assume that the image is already normalized + or not (this disables dtype and bounds checks): when True, it only checks that + the space is a Box and has 3 dimensions. + Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]). :return: """ + check_dtype = check_bounds = not normalized_image if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: # Check the type - if observation_space.dtype != np.uint8: + if check_dtype and observation_space.dtype != np.uint8: return False # Check the value range - if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): + incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255) + if check_bounds and incorrect_bounds: return False # Skip channels check @@ -57,7 +64,7 @@ def is_image_space( n_channels = observation_space.shape[0] else: n_channels = observation_space.shape[-1] - # RGB, RGBD, GrayScale + # GrayScale, RGB, RGBD return n_channels in [1, 3, 4] return False @@ -99,7 +106,7 @@ def preprocess_obs( :return: """ if isinstance(observation_space, spaces.Box): - if is_image_space(observation_space) and normalize_images: + if normalize_images and is_image_space(observation_space): return obs.float() / 255.0 return obs.float() diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 5de2af1..abc8bce 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -18,7 +18,7 @@ class BaseFeaturesExtractor(nn.Module): :param features_dim: Number of features extracted. """ - def __init__(self, observation_space: gym.Space, features_dim: int = 0): + def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None: super().__init__() assert features_dim > 0 self._observation_space = observation_space @@ -37,7 +37,7 @@ class FlattenExtractor(BaseFeaturesExtractor): :param observation_space: """ - def __init__(self, observation_space: gym.Space): + def __init__(self, observation_space: gym.Space) -> None: super().__init__(observation_space, get_flattened_obs_dim(observation_space)) self.flatten = nn.Flatten() @@ -55,19 +55,31 @@ class NatureCNN(BaseFeaturesExtractor): :param observation_space: :param features_dim: Number of features extracted. This corresponds to the number of unit for the last layer. + :param normalized_image: Whether to assume that the image is already normalized + or not (this disables dtype and bounds checks): when True, it only checks that + the space is a Box and has 3 dimensions. + Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]). """ - def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): + def __init__( + self, + observation_space: gym.spaces.Box, + features_dim: int = 512, + normalized_image: bool = False, + ) -> None: super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper - assert is_image_space(observation_space, check_channels=False), ( + assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), ( "You should use NatureCNN " f"only with images not with {observation_space}\n" "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n" "If you are using a custom environment,\n" "please check it using our env checker:\n" - "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html" + "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n" + "If you are using `VecNormalize` or already normalized channel-first images " + "you should pass `normalize_images=False`: \n" + "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html" ) n_input_channels = observation_space.shape[0] self.cnn = nn.Sequential( @@ -167,7 +179,7 @@ class MlpExtractor(nn.Module): net_arch: List[Union[int, Dict[str, List[int]]]], activation_fn: Type[nn.Module], device: Union[th.device, str] = "auto", - ): + ) -> None: super().__init__() device = get_device(device) shared_net: List[nn.Module] = [] @@ -247,9 +259,18 @@ class CombinedExtractor(BaseFeaturesExtractor): :param observation_space: :param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to 256 to avoid exploding network sizes. + :param normalized_image: Whether to assume that the image is already normalized + or not (this disables dtype and bounds checks): when True, it only checks that + the space is a Box and has 3 dimensions. + Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]). """ - def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256): + def __init__( + self, + observation_space: gym.spaces.Dict, + cnn_output_dim: int = 256, + normalized_image: bool = False, + ) -> None: # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! super().__init__(observation_space, features_dim=1) @@ -257,8 +278,8 @@ class CombinedExtractor(BaseFeaturesExtractor): total_concat_size = 0 for key, subspace in observation_space.spaces.items(): - if is_image_space(subspace): - extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim) + if is_image_space(subspace, normalized_image=normalized_image): + extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image) total_concat_size += cnn_output_dim else: # The observation key is a vector, flatten it if needed diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 88d725e..8583518 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -187,7 +187,7 @@ class StackedDictObservations(StackedObservations): def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict: """ - Returns the stacked verson of a Dict observation space + Returns the stacked version of a Dict observation space :param observation_space: Dict observation space to stack :return: stacked observation space diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 73c890c..ad400d1 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -6,6 +6,7 @@ import gym import numpy as np from stable_baselines3.common import utils +from stable_baselines3.common.preprocessing import is_image_space from stable_baselines3.common.running_mean_std import RunningMeanStd from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper @@ -50,9 +51,35 @@ class VecNormalize(VecEnvWrapper): if isinstance(self.observation_space, gym.spaces.Dict): self.obs_spaces = self.observation_space.spaces self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} + # Update observation space when using image + # See explanation below and GH #1214 + for key in self.obs_rms.keys(): + if is_image_space(self.obs_spaces[key]): + self.observation_space.spaces[key] = gym.spaces.Box( + low=-clip_obs, + high=clip_obs, + shape=self.obs_spaces[key].shape, + dtype=np.float32, + ) + else: self.obs_spaces = None self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) + # Update observation space when using image + # See GH #1214 + # This is to raise proper error when + # VecNormalize is used with an image-like input and + # normalize_images=True. + # For correctness, we should also update the bounds + # in other cases but this will cause backward-incompatible change + # and break already saved policies. + if is_image_space(self.observation_space): + self.observation_space = gym.spaces.Box( + low=-clip_obs, + high=clip_obs, + shape=self.observation_space.shape, + dtype=np.float32, + ) self.ret_rms = RunningMeanStd(shape=()) self.clip_obs = clip_obs diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index ed3497c..ee0c00c 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -51,7 +51,6 @@ class QNetwork(BasePolicy): self.activation_fn = activation_fn self.features_extractor = features_extractor self.features_dim = features_dim - self.normalize_images = normalize_images action_dim = self.action_space.n # number of actions q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn) self.q_net = nn.Sequential(*q_net) @@ -125,6 +124,7 @@ class DQNPolicy(BasePolicy): features_extractor_kwargs, optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, ) if net_arch is None: @@ -135,7 +135,6 @@ class DQNPolicy(BasePolicy): self.net_arch = net_arch self.activation_fn = activation_fn - self.normalize_images = normalize_images self.net_args = { "observation_space": self.observation_space, diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index ac93249..6f769c1 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -232,6 +232,7 @@ class SACPolicy(BasePolicy): optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, squash_output=True, + normalize_images=normalize_images, ) if net_arch is None: diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 8781b32..ce908a6 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -129,6 +129,7 @@ class TD3Policy(BasePolicy): optimizer_class=optimizer_class, optimizer_kwargs=optimizer_kwargs, squash_output=True, + normalize_images=normalize_images, ) # Default network architecture, from the original paper @@ -177,7 +178,7 @@ class TD3Policy(BasePolicy): if self.share_features_extractor: self.critic = self.make_critic(features_extractor=self.actor.features_extractor) - # Critic target should not share the features extactor with critic + # Critic target should not share the features extractor with critic # but it can share it with the actor target as actor and critic are sharing # the same features_extractor too # NOTE: as a result the effective poliak (soft-copy) coefficient for the features extractor diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 21392a8..0f2c082 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a6 +1.7.0a7 diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 03f089d..432016a 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.utils import zip_strict -from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage, is_vecenv_wrapped +from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) @@ -269,6 +269,10 @@ def test_image_space_checks(): not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8) assert not is_image_space(not_image_space) + # Deactivate dtype and bound checking + normalized_image = spaces.Box(0, 1, shape=(10, 10, 3), dtype=np.float32) + assert is_image_space(normalized_image, normalized_image=True) + # Not correct space not_image_space = spaces.Discrete(n=10) assert not is_image_space(not_image_space) @@ -297,3 +301,44 @@ def test_image_space_checks(): # Should raise a warning with pytest.warns(Warning): assert not is_image_space_channels_first(channel_mid_space) + + +@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, SAC, TD3]) +@pytest.mark.parametrize("normalize_images", [True, False]) +def test_image_like_input(model_class, normalize_images): + """ + Check that we can handle image-like input (3D tensor) + when normalize_images=False + """ + # Fake grayscale with frameskip + # Atari after preprocessing: 84x84x1, here we are using lower resolution + # to check that the network handle it automatically + env = FakeImageEnv( + screen_height=36, + screen_width=36, + n_channels=1, + channel_first=True, + discrete=model_class not in {SAC, TD3}, + ) + vec_env = VecNormalize(DummyVecEnv([lambda: env])) + # Reduce the size of the features + # deactivate normalization + kwargs = dict( + policy_kwargs=dict( + normalize_images=normalize_images, + features_extractor_kwargs=dict(features_dim=32), + ), + seed=1, + ) + + if model_class in {A2C, PPO}: + kwargs.update(dict(n_steps=64)) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs.update(dict(buffer_size=250)) + if normalize_images: + with pytest.raises(AssertionError): + model_class("CnnPolicy", vec_env, **kwargs).learn(128) + else: + model_class("CnnPolicy", vec_env, **kwargs).learn(128) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 93b13b4..2c114f6 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -28,8 +28,8 @@ class DummyDictEnv(gym.Env): else: self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) N_CHANNELS = 1 - HEIGHT = 64 - WIDTH = 64 + HEIGHT = 36 + WIDTH = 36 if channel_last: obs_shape = (HEIGHT, WIDTH, N_CHANNELS) @@ -323,3 +323,10 @@ def test_dict_nested(): with pytest.raises(NotImplementedError): env = DummyVecEnv([lambda: DummyDictEnv(nested_dict_obs=True)]) + + +def test_vec_normalize_image(): + env = VecNormalize(DummyVecEnv([lambda: DummyDictEnv()]), norm_obs_keys=["img"]) + assert env.observation_space.spaces["img"].dtype == np.float32 + assert (env.observation_space.spaces["img"].low == -env.clip_obs).all() + assert (env.observation_space.spaces["img"].high == env.clip_obs).all()