diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 62ea72a..6adf55d 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -8,8 +8,9 @@ That is to say, your environment must implement the following methods (and inher .. note:: - If you are using images as input, the input values must be in [0, 255] as the observation - is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. + If you are using images as input, the input values must be in [0, 255] and np.uint8 as the observation + is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either + channel-first or channel-last. @@ -28,7 +29,7 @@ That is to say, your environment must implement the following methods (and inher # They must be gym.spaces objects # Example when using discrete actions: self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS) - # Example for using image as input: + # Example for using image as input (can be channel-first or channel-last): self.observation_space = spaces.Box(low=0, high=255, shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8) diff --git a/docs/guide/developer.rst b/docs/guide/developer.rst index 388c454..d930594 100644 --- a/docs/guide/developer.rst +++ b/docs/guide/developer.rst @@ -55,7 +55,8 @@ Pre-Processing To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation). Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``. -For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention. +For images, environment is automatically wrapped with ``VecTransposeImage`` if observations are detected to be images with +channel-last convention to transform it to PyTorch's channel-first convention. 
Policy Structure diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 000e6e5..e614819 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -6,15 +6,20 @@ Changelog Pre-Release 0.11.0a0 (WIP) ------------------------------- - Breaking Changes: ^^^^^^^^^^^^^^^^^ + New Features: ^^^^^^^^^^^^^ +- Add support for ``VecFrameStack`` to stack on first or last observation dimension, along with + automatic check for image spaces. +- ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked + on the first or last observation dimension (originally always stacked on last). Bug Fixes: ^^^^^^^^^^ +- Fixed bug where code added ``VecTransposeImage`` on channel-first image environments (thanks @qxcv) Deprecations: ^^^^^^^^^^^^^ @@ -22,6 +27,7 @@ Deprecations: Others: ^^^^^^^ - Add more issue templates +- Improve error message in ``NatureCNN`` Documentation: ^^^^^^^^^^^^^^ diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 388307c..1d1d5e3 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -101,7 +101,7 @@ Run the benchmark (replace ``$ENV_ID`` by the env id, for instance ``BreakoutNoF .. 
code-block:: bash - python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000 Plot the results: diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index b4195f7..f33b36f 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import BaseCallback, CallbackList, Conve from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.policies import BasePolicy, get_policy_from_name -from stable_baselines3.common.preprocessing import is_image_space +from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import ( @@ -176,7 +176,11 @@ class BaseAlgorithm(ABC): print("Wrapping the env in a DummyVecEnv.") env = DummyVecEnv([lambda: env]) - if is_image_space(env.observation_space) and not is_wrapped(env, VecTransposeImage): + if ( + is_image_space(env.observation_space) + and not is_wrapped(env, VecTransposeImage) + and not is_image_space_channels_first(env.observation_space) + ): if verbose >= 1: print("Wrapping the env in a VecTransposeImage.") env = VecTransposeImage(env) diff --git a/stable_baselines3/common/identity_env.py b/stable_baselines3/common/identity_env.py index 0d6a743..8f6ccd2 100644 --- a/stable_baselines3/common/identity_env.py +++ b/stable_baselines3/common/identity_env.py @@ -112,14 +112,23 @@ class FakeImageEnv(Env): :param screen_height: Height of the image :param screen_width: Width of the image :param n_channels: Number of color channels - :param discrete: + :param discrete: 
Create discrete action space instead of continuous + :param channel_first: Put channels on first axis instead of last """ def __init__( - self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True + self, + action_dim: int = 6, + screen_height: int = 84, + screen_width: int = 84, + n_channels: int = 1, + discrete: bool = True, + channel_first: bool = False, ): - - self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8) + self.observation_shape = (screen_height, screen_width, n_channels) + if channel_first: + self.observation_shape = (n_channels, screen_height, screen_width) + self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8) if discrete: self.action_space = Discrete(action_dim) else: diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 0b84683..881970a 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -1,3 +1,4 @@ +import warnings from typing import Tuple import numpy as np @@ -6,6 +7,23 @@ from gym import spaces from torch.nn import functional as F +def is_image_space_channels_first(observation_space: spaces.Box) -> bool: + """ + Check if an image observation space (see ``is_image_space``) + is channels-first (CxHxW, True) or channels-last (HxWxC, False). + + Use a heuristic that channel dimension is the smallest of the three. + If second dimension is smallest, emit a warning and treat the space as channels-last. + + :param observation_space: + :return: True if observation space is channels-first image, False if channels-last. 
+ """ + smallest_dimension = np.argmin(observation_space.shape).item() + if smallest_dimension == 1: + warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.") + return smallest_dimension == 0 + + def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool: """ Check if a observation space has the shape, limits and dtype diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index e73a684..165d37d 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -65,8 +65,11 @@ class NatureCNN(BaseFeaturesExtractor): # Re-ordering will be done by pre-preprocessing or wrapper assert is_image_space(observation_space), ( "You should use NatureCNN " - f"only with images not with {observation_space} " - "(you are probably using `CnnPolicy` instead of `MlpPolicy`)" + f"only with images not with {observation_space}\n" + "(you are probably using `CnnPolicy` instead of `MlpPolicy`)\n" + "If you are using a custom environment,\n" + "please check it using our env checker:\n" + "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html" ) n_input_channels = observation_space.shape[0] self.cnn = nn.Sequential( diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 94199ff..ff9a796 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -1,27 +1,52 @@ import warnings -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np from gym import spaces +from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper class VecFrameStack(VecEnvWrapper): """ - Frame stacking wrapper 
for vectorized environment + Frame stacking wrapper for vectorized environment. Designed for image observations. + + Dimension to stack over is either first (channels-first) or + last (channels-last), which is detected automatically using + ``common.preprocessing.is_image_space_channels_first`` if + observation is an image space. :param venv: the vectorized environment to wrap :param n_stack: Number of frames to stack + :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. + If None, automatically detect channel to stack over in case of image observation or default to "last" (default). """ - def __init__(self, venv: VecEnv, n_stack: int): + def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[str] = None): self.venv = venv self.n_stack = n_stack + wrapped_obs_space = venv.observation_space assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space" - low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1) - high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1) + + if channels_order is None: + # Detect channel location automatically for images + if is_image_space(wrapped_obs_space): + self.channels_first = is_image_space_channels_first(wrapped_obs_space) + else: + # Default behavior for non-image space, stack on the last axis + self.channels_first = False + else: + assert channels_order in {"last", "first"}, "`channels_order` must be one of following: 'last', 'first'" + + self.channels_first = channels_order == "first" + + # This includes the vec-env dimension (first) + self.stack_dimension = 1 if self.channels_first else -1 + repeat_axis = 0 if self.channels_first else -1 + low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=repeat_axis) + high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=repeat_axis) self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) observation_space = spaces.Box(low=low, 
high=high, dtype=venv.observation_space.dtype) VecEnvWrapper.__init__(self, venv, observation_space=observation_space) @@ -30,18 +55,29 @@ class VecFrameStack(VecEnvWrapper): observations, rewards, dones, infos = self.venv.step_wait() # Let pytype know that observation is not a dict assert isinstance(observations, np.ndarray) - last_ax_size = observations.shape[-1] - self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) + stack_ax_size = observations.shape[self.stack_dimension] + self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension) for i, done in enumerate(dones): if done: if "terminal_observation" in infos[i]: old_terminal = infos[i]["terminal_observation"] - new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) + if self.channels_first: + new_terminal = np.concatenate( + (self.stackedobs[i, :-stack_ax_size, ...], old_terminal), axis=self.stack_dimension + ) + else: + new_terminal = np.concatenate( + (self.stackedobs[i, ..., :-stack_ax_size], old_terminal), axis=self.stack_dimension + ) infos[i]["terminal_observation"] = new_terminal else: warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") self.stackedobs[i] = 0 - self.stackedobs[..., -observations.shape[-1] :] = observations + if self.channels_first: + self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations + else: + self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations + return self.stackedobs, rewards, dones, infos def reset(self) -> np.ndarray: @@ -50,7 +86,10 @@ class VecFrameStack(VecEnvWrapper): """ obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch self.stackedobs[...] = 0 - self.stackedobs[..., -obs.shape[-1] :] = obs + if self.channels_first: + self.stackedobs[:, -obs.shape[self.stack_dimension] :, ...] 
= obs + else: + self.stackedobs[..., -obs.shape[self.stack_dimension] :] = obs return self.stackedobs def close(self) -> None: diff --git a/tests/test_cnn.py b/tests/test_cnn.py index ba1040d..7f85b75 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -4,10 +4,13 @@ from copy import deepcopy import numpy as np import pytest import torch as th +from gym import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.identity_env import FakeImageEnv +from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.utils import zip_strict +from stable_baselines3.common.vec_env import VecTransposeImage, is_wrapped @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) @@ -25,6 +28,9 @@ def test_cnn(tmp_path, model_class): kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) model = model_class("CnnPolicy", env, **kwargs).learn(250) + # FakeImageEnv is channel last by default and should be wrapped + assert is_wrapped(model.get_env(), VecTransposeImage) + obs = env.reset() action, _ = model.predict(obs, deterministic=True) @@ -174,3 +180,73 @@ def test_features_extractor_target_net(model_class, share_features_extractor): params_should_match(original_actor_param, model.actor.parameters()) td3_features_extractor_check(model) + + +def test_channel_first_env(tmp_path): + # test_cnn uses environment with HxWxC setup that is transposed, but we + # also want to work with CxHxW envs directly without transposing wrapper. + SAVE_NAME = "cnn_model.zip" + + # Create environment with transposed images (CxHxW). 
+ # If underlying CNN processes the data in wrong format, + # it will raise an error of negative dimension sizes while creating convolutions + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=True, channel_first=True) + + model = A2C("CnnPolicy", env, n_steps=100).learn(250) + + assert not is_wrapped(model.get_env(), VecTransposeImage) + + obs = env.reset() + + action, _ = model.predict(obs, deterministic=True) + + model.save(tmp_path / SAVE_NAME) + del model + + model = A2C.load(tmp_path / SAVE_NAME) + + # Check that the prediction is the same + assert np.allclose(action, model.predict(obs, deterministic=True)[0]) + + os.remove(str(tmp_path / SAVE_NAME)) + + +def test_image_space_checks(): + not_image_space = spaces.Box(0, 1, shape=(10,)) + assert not is_image_space(not_image_space) + + # Not uint8 + not_image_space = spaces.Box(0, 255, shape=(10, 10, 3)) + assert not is_image_space(not_image_space) + + # Not correct shape + not_image_space = spaces.Box(0, 255, shape=(10, 10), dtype=np.uint8) + assert not is_image_space(not_image_space) + + # Not correct low/high + not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8) + assert not is_image_space(not_image_space) + + # Not correct space + not_image_space = spaces.Discrete(n=10) + assert not is_image_space(not_image_space) + + an_image_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8) + assert is_image_space(an_image_space) + + an_image_space_with_odd_channels = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8) + assert is_image_space(an_image_space_with_odd_channels) + # Should not pass if we check if channels are valid for an image + assert not is_image_space(an_image_space_with_odd_channels, check_channels=True) + + # Test if channel-check works + channel_first_space = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8) + assert is_image_space_channels_first(channel_first_space) + + channel_last_space = spaces.Box(0, 255, shape=(10, 10, 3), 
dtype=np.uint8) + assert not is_image_space_channels_first(channel_last_space) + + channel_mid_space = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8) + # Should raise a warning + with pytest.warns(Warning): + assert not is_image_space_channels_first(channel_mid_space) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 141ca6a..5545ff4 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -341,3 +341,77 @@ def test_vecenv_wrapper_getattr(): _ = double_wrapped.var_b with pytest.raises(AttributeError): # should raise as does not exist _ = double_wrapped.nonexistent_attribute + + +def test_framestack_vecenv(): + """Test that framestack environment stacks on desired axis""" + + image_space_shape = [12, 8, 3] + zero_acts = np.zeros([N_ENVS] + image_space_shape) + + transposed_image_space_shape = image_space_shape[::-1] + transposed_zero_acts = np.zeros([N_ENVS] + transposed_image_space_shape) + + def make_image_env(): + return CustomGymEnv( + gym.spaces.Box( + low=np.zeros(image_space_shape), + high=np.ones(image_space_shape) * 255, + dtype=np.uint8, + ) + ) + + def make_transposed_image_env(): + return CustomGymEnv( + gym.spaces.Box( + low=np.zeros(transposed_image_space_shape), + high=np.ones(transposed_image_space_shape) * 255, + dtype=np.uint8, + ) + ) + + def make_non_image_env(): + return CustomGymEnv(gym.spaces.Box(low=np.zeros((2,)), high=np.ones((2,)))) + + vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) + vec_env = VecFrameStack(vec_env, n_stack=2) + obs, _, _, _ = vec_env.step(zero_acts) + vec_env.close() + + # Should be stacked on the last dimension + assert obs.shape[-1] == (image_space_shape[-1] * 2) + + # Try automatic stacking on first dimension now + vec_env = DummyVecEnv([make_transposed_image_env for _ in range(N_ENVS)]) + vec_env = VecFrameStack(vec_env, n_stack=2) + obs, _, _, _ = vec_env.step(transposed_zero_acts) + vec_env.close() + + # Should be stacked on the first dimension (note the 
transposing in make_transposed_image_env) + assert obs.shape[1] == (image_space_shape[-1] * 2) + + # Try forcing dimensions + vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) + vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="last") + obs, _, _, _ = vec_env.step(zero_acts) + vec_env.close() + + # Should be stacked on the last dimension + assert obs.shape[-1] == (image_space_shape[-1] * 2) + + vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) + vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="first") + obs, _, _, _ = vec_env.step(zero_acts) + vec_env.close() + + # Should be stacked on the first dimension + assert obs.shape[1] == (image_space_shape[0] * 2) + + # Test invalid channels_order + vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)]) + with pytest.raises(AssertionError): + vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="not_valid") + + # Test that it works with non-image envs when no channels_order is given + vec_env = DummyVecEnv([make_non_image_env for _ in range(N_ENVS)]) + vec_env = VecFrameStack(vec_env, n_stack=2)