Fix support of image like normalized inputs (#1214)

* Fix support of image like normalized inputs

* Improve docstring and warning message.

* Don't check if obs is image when normalize_images is False (lil opt)

* Comment fix

* Fix normalize_images not passed to parent

* Check for subclasses too

* Remove useless multiline

* Update version and add comment

* Fix some typos

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
Antonin RAFFIN 2022-12-20 13:18:28 +01:00 committed by GitHub
parent ca944fed2d
commit 8452106734
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 163 additions and 37 deletions

View file

@ -105,7 +105,7 @@ train/
- ``loss``: Current total loss value
- ``n_updates``: Number of gradient updates applied so far
- ``policy_gradient_loss``: Current value of the policy gradient loss (its value does not have much meaning)
- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carle estimate (or TD(lambda) estimate)
- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carlo estimate (or TD(lambda) estimate)
- ``std``: Current standard deviation of the noise when using generalized State-Dependent Exploration (gSDE)

View file

@ -8,9 +8,13 @@ That is to say, your environment must implement the following methods (and inher
.. note::
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255]
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
channel-first or channel-last.
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
Images can be either channel-first or channel-last.
If you want to use ``CnnPolicy`` or ``MultiInputPolicy`` with image-like observation (3D tensor) that are already normalized, you must pass ``normalize_images=False``
to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
and make sure your image is in the **channel-first** format.
.. note::

View file

@ -139,7 +139,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
super(CustomCNN, self).__init__(observation_space, features_dim)
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
n_input_channels = observation_space.shape[0]
@ -201,7 +201,7 @@ downsampling and "vector" with a single linear layer.
# We do not know features-dim here before going over all the items,
# so put something dummy for now. PyTorch requires calling
# nn.Module.__init__ before adding modules
super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)
super().__init__(observation_space, features_dim=1)
extractors = {}
@ -374,7 +374,7 @@ If your task requires even more granular control over the policy/value architect
**kwargs,
):
super(CustomActorCriticPolicy, self).__init__(
super().__init__(
observation_space,
action_space,
lr_schedule,

View file

@ -194,7 +194,7 @@ You can use environments with dictionary observation spaces. This is useful in t
concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles).
Stable Baselines3 provides ``SimpleMultiObsEnv`` as an example of this kind of of setting.
The environment is a simple grid world but the observations for each cell come in the form of dictionaries.
These dictionaries are randomly initilaized on the creation of the environment and contain a vector observation and an image observation.
These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation.
.. code-block:: python

View file

@ -35,7 +35,7 @@ As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/iss
The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
For PPO, assuming a shared feature extactor.
For PPO, assuming a shared feature extractor.
.. warning::

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.7.0a6 (WIP)
Release 1.7.0a7 (WIP)
--------------------------
Breaking Changes:
@ -13,12 +13,16 @@ Breaking Changes:
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Removed ``ret`` attributes in ``VecNormalize``, please use ``returns`` instead
- ``VecNormalize`` now updates the observation space when normalizing images
New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added ``with_bias`` argument to ``create_mlp``
- Added support for multidimensional ``spaces.MultiBinary`` observations
- Features extractors now properly support unnormalized image-like observations (3D tensor)
when passing ``normalize_images=False``
- Added ``normalized_image`` parameter to ``NatureCNN`` and ``CombinedExtractor``
SB3-Contrib
^^^^^^^^^^^
@ -31,6 +35,8 @@ Bug Fixes:
- Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde)
- Fix type annotation of ``model`` in ``evaluate_policy``
- Fixed ``Self`` return type using ``TypeVar``
- Fixed the env checker, the key was not passed when checking images from Dict observation space
- Fixed ``normalize_images`` which was not passed to parent class in some cases
Deprecations:
^^^^^^^^^^^^^
@ -58,6 +64,7 @@ Documentation:
- Changed ``env`` to ``vec_env`` when environment is vectorized
- Updated custom policy docs to better explain the ``mlp_extractor``'s dimensions (@AlexPasqua)
- Update custom policy documentation (@athatheo)
- Clarify doc when using image-like input
Release 1.6.2 (2022-10-10)
--------------------------

View file

@ -155,7 +155,7 @@ Driving policies can be trained in different scenarios, and several notebooks us
tactile-gym
-------------------
Suite of RL environments focussed on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs.
Suite of RL environments focused on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs.
| Author: Alex Church
| GitHub: https://github.com/ac-93/tactile_gym

View file

@ -21,11 +21,15 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input will be compatible with Stable-Baselines
when the observation is apparently an image.
:param observation_space: Observation space
:key: When the observation space comes from a Dict space, we pass the
corresponding key to have more precise warning messages. Defaults to "".
"""
if observation_space.dtype != np.uint8:
warnings.warn(
f"It seems that your observation {key} is an image but the `dtype` "
"of your observation_space is not `np.uint8`. "
f"It seems that your observation {key} is an image but its `dtype` "
f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. "
"If your observation is not an image, we recommend you to flatten the observation "
"to have only a 1D vector"
)
@ -180,7 +184,7 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
# If image, check the low and high values, the type and the number of channels
# and the shape (minimal value)
if len(observation_space.shape) == 3:
_check_image_input(observation_space)
_check_image_input(observation_space, key)
if len(observation_space.shape) not in [1, 3]:
warnings.warn(

View file

@ -73,7 +73,7 @@ class Image:
class HParam:
"""
Hyperparameter data class storing hyperparameters and metrics in dictionnaries
Hyperparameter data class storing hyperparameters and metrics in dictionaries
:param hparam_dict: key-value pairs of hyperparameters to log
:param metric_dict: key-value pairs of metrics to log

View file

@ -87,6 +87,9 @@ class BaseModel(nn.Module):
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
# Automatically deactivate dtype and bounds checks
if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
self.features_extractor_kwargs.update(dict(normalized_image=True))
def _update_features_extractor(
self,
@ -430,6 +433,7 @@ class ActorCriticPolicy(BasePolicy):
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=squash_output,
normalize_images=normalize_images,
)
# Default network architecture, from stable-baselines
@ -446,7 +450,6 @@ class ActorCriticPolicy(BasePolicy):
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.normalize_images = normalize_images
self.log_std_init = log_std_init
dist_kwargs = None
# Keyword arguments for gSDE distribution

View file

@ -27,6 +27,7 @@ def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
def is_image_space(
observation_space: spaces.Space,
check_channels: bool = False,
normalized_image: bool = False,
) -> bool:
"""
Check if a observation space has the shape, limits and dtype
@ -38,15 +39,21 @@ def is_image_space(
:param observation_space:
:param check_channels: Whether to do or not the check for the number of channels.
e.g., with frame-stacking, the observation space may have more channels than expected.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
:return:
"""
check_dtype = check_bounds = not normalized_image
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
# Check the type
if observation_space.dtype != np.uint8:
if check_dtype and observation_space.dtype != np.uint8:
return False
# Check the value range
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255)
if check_bounds and incorrect_bounds:
return False
# Skip channels check
@ -57,7 +64,7 @@ def is_image_space(
n_channels = observation_space.shape[0]
else:
n_channels = observation_space.shape[-1]
# RGB, RGBD, GrayScale
# GrayScale, RGB, RGBD
return n_channels in [1, 3, 4]
return False
@ -99,7 +106,7 @@ def preprocess_obs(
:return:
"""
if isinstance(observation_space, spaces.Box):
if is_image_space(observation_space) and normalize_images:
if normalize_images and is_image_space(observation_space):
return obs.float() / 255.0
return obs.float()

View file

@ -18,7 +18,7 @@ class BaseFeaturesExtractor(nn.Module):
:param features_dim: Number of features extracted.
"""
def __init__(self, observation_space: gym.Space, features_dim: int = 0):
def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
super().__init__()
assert features_dim > 0
self._observation_space = observation_space
@ -37,7 +37,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
:param observation_space:
"""
def __init__(self, observation_space: gym.Space):
def __init__(self, observation_space: gym.Space) -> None:
super().__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()
@ -55,19 +55,31 @@ class NatureCNN(BaseFeaturesExtractor):
:param observation_space:
:param features_dim: Number of features extracted.
This corresponds to the number of unit for the last layer.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
def __init__(
self,
observation_space: gym.spaces.Box,
features_dim: int = 512,
normalized_image: bool = False,
) -> None:
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space, check_channels=False), (
assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
"You should use NatureCNN "
f"only with images not with {observation_space}\n"
"(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
"If you are using a custom environment,\n"
"please check it using our env checker:\n"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
"If you are using `VecNormalize` or already normalized channel-first images "
"you should pass `normalize_images=False`: \n"
"https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
)
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
@ -167,7 +179,7 @@ class MlpExtractor(nn.Module):
net_arch: List[Union[int, Dict[str, List[int]]]],
activation_fn: Type[nn.Module],
device: Union[th.device, str] = "auto",
):
) -> None:
super().__init__()
device = get_device(device)
shared_net: List[nn.Module] = []
@ -247,9 +259,18 @@ class CombinedExtractor(BaseFeaturesExtractor):
:param observation_space:
:param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to
256 to avoid exploding network sizes.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
"""
def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
def __init__(
self,
observation_space: gym.spaces.Dict,
cnn_output_dim: int = 256,
normalized_image: bool = False,
) -> None:
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super().__init__(observation_space, features_dim=1)
@ -257,8 +278,8 @@ class CombinedExtractor(BaseFeaturesExtractor):
total_concat_size = 0
for key, subspace in observation_space.spaces.items():
if is_image_space(subspace):
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim)
if is_image_space(subspace, normalized_image=normalized_image):
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image)
total_concat_size += cnn_output_dim
else:
# The observation key is a vector, flatten it if needed

View file

@ -187,7 +187,7 @@ class StackedDictObservations(StackedObservations):
def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict:
"""
Returns the stacked verson of a Dict observation space
Returns the stacked version of a Dict observation space
:param observation_space: Dict observation space to stack
:return: stacked observation space

View file

@ -6,6 +6,7 @@ import gym
import numpy as np
from stable_baselines3.common import utils
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
@ -50,9 +51,35 @@ class VecNormalize(VecEnvWrapper):
if isinstance(self.observation_space, gym.spaces.Dict):
self.obs_spaces = self.observation_space.spaces
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
# Update observation space when using image
# See explanation below and GH #1214
for key in self.obs_rms.keys():
if is_image_space(self.obs_spaces[key]):
self.observation_space.spaces[key] = gym.spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.obs_spaces[key].shape,
dtype=np.float32,
)
else:
self.obs_spaces = None
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
# Update observation space when using image
# See GH #1214
# This is to raise proper error when
# VecNormalize is used with an image-like input and
# normalize_images=True.
# For correctness, we should also update the bounds
# in other cases but this will cause backward-incompatible change
# and break already saved policies.
if is_image_space(self.observation_space):
self.observation_space = gym.spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.observation_space.shape,
dtype=np.float32,
)
self.ret_rms = RunningMeanStd(shape=())
self.clip_obs = clip_obs

View file

@ -51,7 +51,6 @@ class QNetwork(BasePolicy):
self.activation_fn = activation_fn
self.features_extractor = features_extractor
self.features_dim = features_dim
self.normalize_images = normalize_images
action_dim = self.action_space.n # number of actions
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
self.q_net = nn.Sequential(*q_net)
@ -125,6 +124,7 @@ class DQNPolicy(BasePolicy):
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
normalize_images=normalize_images,
)
if net_arch is None:
@ -135,7 +135,6 @@ class DQNPolicy(BasePolicy):
self.net_arch = net_arch
self.activation_fn = activation_fn
self.normalize_images = normalize_images
self.net_args = {
"observation_space": self.observation_space,

View file

@ -232,6 +232,7 @@ class SACPolicy(BasePolicy):
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True,
normalize_images=normalize_images,
)
if net_arch is None:

View file

@ -129,6 +129,7 @@ class TD3Policy(BasePolicy):
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=True,
normalize_images=normalize_images,
)
# Default network architecture, from the original paper
@ -177,7 +178,7 @@ class TD3Policy(BasePolicy):
if self.share_features_extractor:
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
# Critic target should not share the features extactor with critic
# Critic target should not share the features extractor with critic
# but it can share it with the actor target as actor and critic are sharing
# the same features_extractor too
# NOTE: as a result the effective poliak (soft-copy) coefficient for the features extractor

View file

@ -1 +1 @@
1.7.0a6
1.7.0a7

View file

@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage, is_vecenv_wrapped
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
@ -269,6 +269,10 @@ def test_image_space_checks():
not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8)
assert not is_image_space(not_image_space)
# Deactivate dtype and bound checking
normalized_image = spaces.Box(0, 1, shape=(10, 10, 3), dtype=np.float32)
assert is_image_space(normalized_image, normalized_image=True)
# Not correct space
not_image_space = spaces.Discrete(n=10)
assert not is_image_space(not_image_space)
@ -297,3 +301,44 @@ def test_image_space_checks():
# Should raise a warning
with pytest.warns(Warning):
assert not is_image_space_channels_first(channel_mid_space)
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, SAC, TD3])
@pytest.mark.parametrize("normalize_images", [True, False])
def test_image_like_input(model_class, normalize_images):
"""
Check that we can handle image-like input (3D tensor)
when normalize_images=False
"""
# Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution
# to check that the network handle it automatically
env = FakeImageEnv(
screen_height=36,
screen_width=36,
n_channels=1,
channel_first=True,
discrete=model_class not in {SAC, TD3},
)
vec_env = VecNormalize(DummyVecEnv([lambda: env]))
# Reduce the size of the features
# deactivate normalization
kwargs = dict(
policy_kwargs=dict(
normalize_images=normalize_images,
features_extractor_kwargs=dict(features_dim=32),
),
seed=1,
)
if model_class in {A2C, PPO}:
kwargs.update(dict(n_steps=64))
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs.update(dict(buffer_size=250))
if normalize_images:
with pytest.raises(AssertionError):
model_class("CnnPolicy", vec_env, **kwargs).learn(128)
else:
model_class("CnnPolicy", vec_env, **kwargs).learn(128)

View file

@ -28,8 +28,8 @@ class DummyDictEnv(gym.Env):
else:
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
N_CHANNELS = 1
HEIGHT = 64
WIDTH = 64
HEIGHT = 36
WIDTH = 36
if channel_last:
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
@ -323,3 +323,10 @@ def test_dict_nested():
with pytest.raises(NotImplementedError):
env = DummyVecEnv([lambda: DummyDictEnv(nested_dict_obs=True)])
def test_vec_normalize_image():
env = VecNormalize(DummyVecEnv([lambda: DummyDictEnv()]), norm_obs_keys=["img"])
assert env.observation_space.spaces["img"].dtype == np.float32
assert (env.observation_space.spaces["img"].low == -env.clip_obs).all()
assert (env.observation_space.spaces["img"].high == env.clip_obs).all()