Avoid transposing channel-first envs (#213)

* Add test for channel-first environments

* Add support for channel-first envs, including more tests

* Update changelog

* Run black

* Run black, again

* Improve NatureCNN error message

* Update image checks and FrameStack wrapper

* Update tests

* Update docs

* Run isort

* Reformat

* Fixes: avoid breaking changes for non-image env

* Add additional checks

* Update docstring

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Anssi 2020-11-03 13:34:09 +02:00 committed by GitHub
parent 9d463bc476
commit e2b6f5460f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 255 additions and 24 deletions

View file

@ -8,8 +8,9 @@ That is to say, your environment must implement the following methods (and inher
.. note::
If you are using images as input, the input values must be in [0, 255] as the observation
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies.
If you are using images as input, the input values must be in [0, 255] and np.uint8 as the observation
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
channel-first or channel-last.
@ -28,7 +29,7 @@ That is to say, your environment must implement the following methods (and inher
# They must be gym.spaces objects
# Example when using discrete actions:
self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
# Example for using image as input:
# Example for using image as input (can be channel-first or channel-last):
self.observation_space = spaces.Box(low=0, high=255,
shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)

View file

@ -55,7 +55,8 @@ Pre-Processing
To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation).
Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``.
For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention.
For images, environment is automatically wrapped with ``VecTransposeImage`` if observations are detected to be images with
channel-last convention to transform it to PyTorch's channel-first convention.
Policy Structure

View file

@ -6,15 +6,20 @@ Changelog
Pre-Release 0.11.0a0 (WIP)
-------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Add support for ``VecFrameStack`` to stack on first or last observation dimension, along with
automatic check for image spaces.
- ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked
on the first or last observation dimension (originally always stacked on last).
Bug Fixes:
^^^^^^^^^^
- Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv)
Deprecations:
^^^^^^^^^^^^^
@ -22,6 +27,7 @@ Deprecations:
Others:
^^^^^^^
- Add more issue templates
- Improve error message in ``NatureCNN``
Documentation:
^^^^^^^^^^^^^^

View file

@ -101,7 +101,7 @@ Run the benchmark (replace ``$ENV_ID`` by the env id, for instance ``BreakoutNoF
.. code-block:: bash
python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000
python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
Plot the results:

View file

@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import BaseCallback, CallbackList, Conve
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import (
@ -176,7 +176,11 @@ class BaseAlgorithm(ABC):
print("Wrapping the env in a DummyVecEnv.")
env = DummyVecEnv([lambda: env])
if is_image_space(env.observation_space) and not is_wrapped(env, VecTransposeImage):
if (
is_image_space(env.observation_space)
and not is_wrapped(env, VecTransposeImage)
and not is_image_space_channels_first(env.observation_space)
):
if verbose >= 1:
print("Wrapping the env in a VecTransposeImage.")
env = VecTransposeImage(env)

View file

@ -112,14 +112,23 @@ class FakeImageEnv(Env):
:param screen_height: Height of the image
:param screen_width: Width of the image
:param n_channels: Number of color channels
:param discrete:
:param discrete: Create discrete action space instead of continuous
:param channel_first: Put channels on first axis instead of last
"""
def __init__(
self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True
self,
action_dim: int = 6,
screen_height: int = 84,
screen_width: int = 84,
n_channels: int = 1,
discrete: bool = True,
channel_first: bool = False,
):
self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8)
self.observation_shape = (screen_height, screen_width, n_channels)
if channel_first:
self.observation_shape = (n_channels, screen_height, screen_width)
self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
if discrete:
self.action_space = Discrete(action_dim)
else:

View file

@ -1,3 +1,4 @@
import warnings
from typing import Tuple
import numpy as np
@ -6,6 +7,23 @@ from gym import spaces
from torch.nn import functional as F
def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
"""
Check if an image observation space (see ``is_image_space``)
is channels-first (CxHxW, True) or channels-last (HxWxC, False).
Use a heuristic that channel dimension is the smallest of the three.
If second dimension is smallest, raise an exception (no support).
:param observation_space:
:return: True if observation space is channels-first image, False if channels-last.
"""
smallest_dimension = np.argmin(observation_space.shape).item()
if smallest_dimension == 1:
warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
return smallest_dimension == 0
def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool:
"""
Check if a observation space has the shape, limits and dtype

View file

@ -65,8 +65,11 @@ class NatureCNN(BaseFeaturesExtractor):
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space), (
"You should use NatureCNN "
f"only with images not with {observation_space} "
"(you are probably using `CnnPolicy` instead of `MlpPolicy`)"
f"only with images not with {observation_space}\n"
"(you are probably using `CnnPolicy` instead of `MlpPolicy`)\n"
"If you are using a custom environment,\n"
"please check it using our env checker:\n"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
)
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(

View file

@ -1,27 +1,52 @@
import warnings
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from gym import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment
Frame stacking wrapper for vectorized environment. Designed for image observations.
Dimension to stack over is either first (channels-first) or
last (channels-last), which is detected automatically using
``common.preprocessing.is_image_space_channels_first`` if
observation is an image space.
:param venv: the vectorized environment to wrap
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
"""
def __init__(self, venv: VecEnv, n_stack: int):
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[str] = None):
self.venv = venv
self.n_stack = n_stack
wrapped_obs_space = venv.observation_space
assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space"
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
if channels_order is None:
# Detect channel location automatically for images
if is_image_space(wrapped_obs_space):
self.channels_first = is_image_space_channels_first(wrapped_obs_space)
else:
# Default behavior for non-image space, stack on the last axis
self.channels_first = False
else:
assert channels_order in {"last", "first"}, "`channels_order` must be one of following: 'last', 'first'"
self.channels_first = channels_order == "first"
# This includes the vec-env dimension (first)
self.stack_dimension = 1 if self.channels_first else -1
repeat_axis = 0 if self.channels_first else -1
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=repeat_axis)
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=repeat_axis)
self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
@ -30,18 +55,29 @@ class VecFrameStack(VecEnvWrapper):
observations, rewards, dones, infos = self.venv.step_wait()
# Let pytype know that observation is not a dict
assert isinstance(observations, np.ndarray)
last_ax_size = observations.shape[-1]
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
stack_ax_size = observations.shape[self.stack_dimension]
self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension)
for i, done in enumerate(dones):
if done:
if "terminal_observation" in infos[i]:
old_terminal = infos[i]["terminal_observation"]
new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
if self.channels_first:
new_terminal = np.concatenate(
(self.stackedobs[i, :-stack_ax_size, ...], old_terminal), axis=self.stack_dimension
)
else:
new_terminal = np.concatenate(
(self.stackedobs[i, ..., :-stack_ax_size], old_terminal), axis=self.stack_dimension
)
infos[i]["terminal_observation"] = new_terminal
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stackedobs[i] = 0
self.stackedobs[..., -observations.shape[-1] :] = observations
if self.channels_first:
self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations
else:
self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations
return self.stackedobs, rewards, dones, infos
def reset(self) -> np.ndarray:
@ -50,7 +86,10 @@ class VecFrameStack(VecEnvWrapper):
"""
obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch
self.stackedobs[...] = 0
self.stackedobs[..., -obs.shape[-1] :] = obs
if self.channels_first:
self.stackedobs[:, -obs.shape[self.stack_dimension] :, ...] = obs
else:
self.stackedobs[..., -obs.shape[self.stack_dimension] :] = obs
return self.stackedobs
def close(self) -> None:

View file

@ -4,10 +4,13 @@ from copy import deepcopy
import numpy as np
import pytest
import torch as th
from gym import spaces
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_wrapped
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
@ -25,6 +28,9 @@ def test_cnn(tmp_path, model_class):
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
model = model_class("CnnPolicy", env, **kwargs).learn(250)
# FakeImageEnv is channel last by default and should be wrapped
assert is_wrapped(model.get_env(), VecTransposeImage)
obs = env.reset()
action, _ = model.predict(obs, deterministic=True)
@ -174,3 +180,73 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
params_should_match(original_actor_param, model.actor.parameters())
td3_features_extractor_check(model)
def test_channel_first_env(tmp_path):
# test_cnn uses environment with HxWxC setup that is transposed, but we
# also want to work with CxHxW envs directly without transposing wrapper.
SAVE_NAME = "cnn_model.zip"
# Create environment with transposed images (CxHxW).
# If underlying CNN processes the data in wrong format,
# it will raise an error of negative dimension sizes while creating convolutions
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=True, channel_first=True)
model = A2C("CnnPolicy", env, n_steps=100).learn(250)
assert not is_wrapped(model.get_env(), VecTransposeImage)
obs = env.reset()
action, _ = model.predict(obs, deterministic=True)
model.save(tmp_path / SAVE_NAME)
del model
model = A2C.load(tmp_path / SAVE_NAME)
# Check that the prediction is the same
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
os.remove(str(tmp_path / SAVE_NAME))
def test_image_space_checks():
not_image_space = spaces.Box(0, 1, shape=(10,))
assert not is_image_space(not_image_space)
# Not uint8
not_image_space = spaces.Box(0, 255, shape=(10, 10, 3))
assert not is_image_space(not_image_space)
# Not correct shape
not_image_space = spaces.Box(0, 255, shape=(10, 10), dtype=np.uint8)
assert not is_image_space(not_image_space)
# Not correct low/high
not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8)
assert not is_image_space(not_image_space)
# Not correct space
not_image_space = spaces.Discrete(n=10)
assert not is_image_space(not_image_space)
an_image_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
assert is_image_space(an_image_space)
an_image_space_with_odd_channels = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8)
assert is_image_space(an_image_space_with_odd_channels)
# Should not pass if we check if channels are valid for an image
assert not is_image_space(an_image_space_with_odd_channels, check_channels=True)
# Test if channel-check works
channel_first_space = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8)
assert is_image_space_channels_first(channel_first_space)
channel_last_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
assert not is_image_space_channels_first(channel_last_space)
channel_mid_space = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8)
# Should raise a warning
with pytest.warns(Warning):
assert not is_image_space_channels_first(channel_mid_space)

View file

@ -341,3 +341,77 @@ def test_vecenv_wrapper_getattr():
_ = double_wrapped.var_b
with pytest.raises(AttributeError): # should raise as does not exist
_ = double_wrapped.nonexistent_attribute
def test_framestack_vecenv():
"""Test that framestack environment stacks on desired axis"""
image_space_shape = [12, 8, 3]
zero_acts = np.zeros([N_ENVS] + image_space_shape)
transposed_image_space_shape = image_space_shape[::-1]
transposed_zero_acts = np.zeros([N_ENVS] + transposed_image_space_shape)
def make_image_env():
return CustomGymEnv(
gym.spaces.Box(
low=np.zeros(image_space_shape),
high=np.ones(image_space_shape) * 255,
dtype=np.uint8,
)
)
def make_transposed_image_env():
return CustomGymEnv(
gym.spaces.Box(
low=np.zeros(transposed_image_space_shape),
high=np.ones(transposed_image_space_shape) * 255,
dtype=np.uint8,
)
)
def make_non_image_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
vec_env = VecFrameStack(vec_env, n_stack=2)
obs, _, _, _ = vec_env.step(zero_acts)
vec_env.close()
# Should be stacked on the last dimension
assert obs.shape[-1] == (image_space_shape[-1] * 2)
# Try automatic stacking on first dimension now
vec_env = DummyVecEnv([make_transposed_image_env for _ in range(N_ENVS)])
vec_env = VecFrameStack(vec_env, n_stack=2)
obs, _, _, _ = vec_env.step(transposed_zero_acts)
vec_env.close()
# Should be stacked on the first dimension (note the transposing in make_transposed_image_env)
assert obs.shape[1] == (image_space_shape[-1] * 2)
# Try forcing dimensions
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="last")
obs, _, _, _ = vec_env.step(zero_acts)
vec_env.close()
# Should be stacked on the last dimension
assert obs.shape[-1] == (image_space_shape[-1] * 2)
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="first")
obs, _, _, _ = vec_env.step(zero_acts)
vec_env.close()
# Should be stacked on the first dimension
assert obs.shape[1] == (image_space_shape[0] * 2)
# Test invalid channels_order
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
with pytest.raises(AssertionError):
vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="not_valid")
# Test that it works with non-image envs when no channels_order is given
vec_env = DummyVecEnv([make_non_image_env for _ in range(N_ENVS)])
vec_env = VecFrameStack(vec_env, n_stack=2)