mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Avoid transposing channel-first envs (#213)
* Add test for channel-first environments * Add support for channel-first envs, including more tests * Update changelog * Run black * Run black, again * Improve NatureCNN error message * Update image checks and FrameStack wrapper * Update tests * Update docs * Run isort * Reformat * Fixes: avoid breaking changes for non-image env * Add additional checks * Update docstring Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
9d463bc476
commit
e2b6f5460f
11 changed files with 255 additions and 24 deletions
|
|
@ -8,8 +8,9 @@ That is to say, your environment must implement the following methods (and inher
|
|||
|
||||
|
||||
.. note::
|
||||
If you are using images as input, the input values must be in [0, 255] as the observation
|
||||
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies.
|
||||
If you are using images as input, the input values must be in [0, 255] and np.uint8 as the observation
|
||||
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
|
||||
channel-first or channel-last.
|
||||
|
||||
|
||||
|
||||
|
|
@ -28,7 +29,7 @@ That is to say, your environment must implement the following methods (and inher
|
|||
# They must be gym.spaces objects
|
||||
# Example when using discrete actions:
|
||||
self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
|
||||
# Example for using image as input:
|
||||
# Example for using image as input (can be channel-first or channel-last):
|
||||
self.observation_space = spaces.Box(low=0, high=255,
|
||||
shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,8 @@ Pre-Processing
|
|||
To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation).
|
||||
Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``.
|
||||
|
||||
For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention.
|
||||
For images, environment is automatically wrapped with ``VecTransposeImage`` if observations are detected to be images with
|
||||
channel-last convention to transform it to PyTorch's channel-first convention.
|
||||
|
||||
|
||||
Policy Structure
|
||||
|
|
|
|||
|
|
@ -6,15 +6,20 @@ Changelog
|
|||
Pre-Release 0.11.0a0 (WIP)
|
||||
-------------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Add support for ``VecFrameStack`` to stack on first or last observation dimension, along with
|
||||
automatic check for image spaces.
|
||||
- ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked
|
||||
on the first or last observation dimension (originally always stacked on last).
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -22,6 +27,7 @@ Deprecations:
|
|||
Others:
|
||||
^^^^^^^
|
||||
- Add more issue templates
|
||||
- Improve error message in ``NatureCNN``
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ Run the benchmark (replace ``$ENV_ID`` by the env id, for instance ``BreakoutNoF
|
|||
|
||||
.. code-block:: bash
|
||||
|
||||
python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||
python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
|
||||
|
||||
|
||||
Plot the results:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import BaseCallback, CallbackList, Conve
|
|||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
|
||||
from stable_baselines3.common.preprocessing import is_image_space
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.utils import (
|
||||
|
|
@ -176,7 +176,11 @@ class BaseAlgorithm(ABC):
|
|||
print("Wrapping the env in a DummyVecEnv.")
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
if is_image_space(env.observation_space) and not is_wrapped(env, VecTransposeImage):
|
||||
if (
|
||||
is_image_space(env.observation_space)
|
||||
and not is_wrapped(env, VecTransposeImage)
|
||||
and not is_image_space_channels_first(env.observation_space)
|
||||
):
|
||||
if verbose >= 1:
|
||||
print("Wrapping the env in a VecTransposeImage.")
|
||||
env = VecTransposeImage(env)
|
||||
|
|
|
|||
|
|
@ -112,14 +112,23 @@ class FakeImageEnv(Env):
|
|||
:param screen_height: Height of the image
|
||||
:param screen_width: Width of the image
|
||||
:param n_channels: Number of color channels
|
||||
:param discrete:
|
||||
:param discrete: Create discrete action space instead of continuous
|
||||
:param channel_first: Put channels on first axis instead of last
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True
|
||||
self,
|
||||
action_dim: int = 6,
|
||||
screen_height: int = 84,
|
||||
screen_width: int = 84,
|
||||
n_channels: int = 1,
|
||||
discrete: bool = True,
|
||||
channel_first: bool = False,
|
||||
):
|
||||
|
||||
self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8)
|
||||
self.observation_shape = (screen_height, screen_width, n_channels)
|
||||
if channel_first:
|
||||
self.observation_shape = (n_channels, screen_height, screen_width)
|
||||
self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
|
||||
if discrete:
|
||||
self.action_space = Discrete(action_dim)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -6,6 +7,23 @@ from gym import spaces
|
|||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
|
||||
"""
|
||||
Check if an image observation space (see ``is_image_space``)
|
||||
is channels-first (CxHxW, True) or channels-last (HxWxC, False).
|
||||
|
||||
Use a heuristic that channel dimension is the smallest of the three.
|
||||
If second dimension is smallest, raise an exception (no support).
|
||||
|
||||
:param observation_space:
|
||||
:return: True if observation space is channels-first image, False if channels-last.
|
||||
"""
|
||||
smallest_dimension = np.argmin(observation_space.shape).item()
|
||||
if smallest_dimension == 1:
|
||||
warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
|
||||
return smallest_dimension == 0
|
||||
|
||||
|
||||
def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool:
|
||||
"""
|
||||
Check if a observation space has the shape, limits and dtype
|
||||
|
|
|
|||
|
|
@ -65,8 +65,11 @@ class NatureCNN(BaseFeaturesExtractor):
|
|||
# Re-ordering will be done by pre-preprocessing or wrapper
|
||||
assert is_image_space(observation_space), (
|
||||
"You should use NatureCNN "
|
||||
f"only with images not with {observation_space} "
|
||||
"(you are probably using `CnnPolicy` instead of `MlpPolicy`)"
|
||||
f"only with images not with {observation_space}\n"
|
||||
"(you are probably using `CnnPolicy` instead of `MlpPolicy`)\n"
|
||||
"If you are using a custom environment,\n"
|
||||
"please check it using our env checker:\n"
|
||||
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
|
||||
)
|
||||
n_input_channels = observation_space.shape[0]
|
||||
self.cnn = nn.Sequential(
|
||||
|
|
|
|||
|
|
@ -1,27 +1,52 @@
|
|||
import warnings
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
|
||||
|
||||
|
||||
class VecFrameStack(VecEnvWrapper):
|
||||
"""
|
||||
Frame stacking wrapper for vectorized environment
|
||||
Frame stacking wrapper for vectorized environment. Designed for image observations.
|
||||
|
||||
Dimension to stack over is either first (channels-first) or
|
||||
last (channels-last), which is detected automatically using
|
||||
``common.preprocessing.is_image_space_channels_first`` if
|
||||
observation is an image space.
|
||||
|
||||
:param venv: the vectorized environment to wrap
|
||||
:param n_stack: Number of frames to stack
|
||||
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
|
||||
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
|
||||
"""
|
||||
|
||||
def __init__(self, venv: VecEnv, n_stack: int):
|
||||
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[str] = None):
|
||||
self.venv = venv
|
||||
self.n_stack = n_stack
|
||||
|
||||
wrapped_obs_space = venv.observation_space
|
||||
assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space"
|
||||
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
|
||||
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)
|
||||
|
||||
if channels_order is None:
|
||||
# Detect channel location automatically for images
|
||||
if is_image_space(wrapped_obs_space):
|
||||
self.channels_first = is_image_space_channels_first(wrapped_obs_space)
|
||||
else:
|
||||
# Default behavior for non-image space, stack on the last axis
|
||||
self.channels_first = False
|
||||
else:
|
||||
assert channels_order in {"last", "first"}, "`channels_order` must be one of following: 'last', 'first'"
|
||||
|
||||
self.channels_first = channels_order == "first"
|
||||
|
||||
# This includes the vec-env dimension (first)
|
||||
self.stack_dimension = 1 if self.channels_first else -1
|
||||
repeat_axis = 0 if self.channels_first else -1
|
||||
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=repeat_axis)
|
||||
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=repeat_axis)
|
||||
self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
|
||||
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
|
||||
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
|
||||
|
|
@ -30,18 +55,29 @@ class VecFrameStack(VecEnvWrapper):
|
|||
observations, rewards, dones, infos = self.venv.step_wait()
|
||||
# Let pytype know that observation is not a dict
|
||||
assert isinstance(observations, np.ndarray)
|
||||
last_ax_size = observations.shape[-1]
|
||||
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
|
||||
stack_ax_size = observations.shape[self.stack_dimension]
|
||||
self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension)
|
||||
for i, done in enumerate(dones):
|
||||
if done:
|
||||
if "terminal_observation" in infos[i]:
|
||||
old_terminal = infos[i]["terminal_observation"]
|
||||
new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
|
||||
if self.channels_first:
|
||||
new_terminal = np.concatenate(
|
||||
(self.stackedobs[i, :-stack_ax_size, ...], old_terminal), axis=self.stack_dimension
|
||||
)
|
||||
else:
|
||||
new_terminal = np.concatenate(
|
||||
(self.stackedobs[i, ..., :-stack_ax_size], old_terminal), axis=self.stack_dimension
|
||||
)
|
||||
infos[i]["terminal_observation"] = new_terminal
|
||||
else:
|
||||
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
|
||||
self.stackedobs[i] = 0
|
||||
self.stackedobs[..., -observations.shape[-1] :] = observations
|
||||
if self.channels_first:
|
||||
self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations
|
||||
else:
|
||||
self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations
|
||||
|
||||
return self.stackedobs, rewards, dones, infos
|
||||
|
||||
def reset(self) -> np.ndarray:
|
||||
|
|
@ -50,7 +86,10 @@ class VecFrameStack(VecEnvWrapper):
|
|||
"""
|
||||
obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch
|
||||
self.stackedobs[...] = 0
|
||||
self.stackedobs[..., -obs.shape[-1] :] = obs
|
||||
if self.channels_first:
|
||||
self.stackedobs[:, -obs.shape[self.stack_dimension] :, ...] = obs
|
||||
else:
|
||||
self.stackedobs[..., -obs.shape[self.stack_dimension] :] = obs
|
||||
return self.stackedobs
|
||||
|
||||
def close(self) -> None:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,13 @@ from copy import deepcopy
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_wrapped
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
|
||||
|
|
@ -25,6 +28,9 @@ def test_cnn(tmp_path, model_class):
|
|||
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
# FakeImageEnv is channel last by default and should be wrapped
|
||||
assert is_wrapped(model.get_env(), VecTransposeImage)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
|
@ -174,3 +180,73 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
|
|||
params_should_match(original_actor_param, model.actor.parameters())
|
||||
|
||||
td3_features_extractor_check(model)
|
||||
|
||||
|
||||
def test_channel_first_env(tmp_path):
|
||||
# test_cnn uses environment with HxWxC setup that is transposed, but we
|
||||
# also want to work with CxHxW envs directly without transposing wrapper.
|
||||
SAVE_NAME = "cnn_model.zip"
|
||||
|
||||
# Create environment with transposed images (CxHxW).
|
||||
# If underlying CNN processes the data in wrong format,
|
||||
# it will raise an error of negative dimension sizes while creating convolutions
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=True, channel_first=True)
|
||||
|
||||
model = A2C("CnnPolicy", env, n_steps=100).learn(250)
|
||||
|
||||
assert not is_wrapped(model.get_env(), VecTransposeImage)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
||||
model.save(tmp_path / SAVE_NAME)
|
||||
del model
|
||||
|
||||
model = A2C.load(tmp_path / SAVE_NAME)
|
||||
|
||||
# Check that the prediction is the same
|
||||
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
|
||||
|
||||
os.remove(str(tmp_path / SAVE_NAME))
|
||||
|
||||
|
||||
def test_image_space_checks():
|
||||
not_image_space = spaces.Box(0, 1, shape=(10,))
|
||||
assert not is_image_space(not_image_space)
|
||||
|
||||
# Not uint8
|
||||
not_image_space = spaces.Box(0, 255, shape=(10, 10, 3))
|
||||
assert not is_image_space(not_image_space)
|
||||
|
||||
# Not correct shape
|
||||
not_image_space = spaces.Box(0, 255, shape=(10, 10), dtype=np.uint8)
|
||||
assert not is_image_space(not_image_space)
|
||||
|
||||
# Not correct low/high
|
||||
not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8)
|
||||
assert not is_image_space(not_image_space)
|
||||
|
||||
# Not correct space
|
||||
not_image_space = spaces.Discrete(n=10)
|
||||
assert not is_image_space(not_image_space)
|
||||
|
||||
an_image_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
|
||||
assert is_image_space(an_image_space)
|
||||
|
||||
an_image_space_with_odd_channels = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8)
|
||||
assert is_image_space(an_image_space_with_odd_channels)
|
||||
# Should not pass if we check if channels are valid for an image
|
||||
assert not is_image_space(an_image_space_with_odd_channels, check_channels=True)
|
||||
|
||||
# Test if channel-check works
|
||||
channel_first_space = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8)
|
||||
assert is_image_space_channels_first(channel_first_space)
|
||||
|
||||
channel_last_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
|
||||
assert not is_image_space_channels_first(channel_last_space)
|
||||
|
||||
channel_mid_space = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8)
|
||||
# Should raise a warning
|
||||
with pytest.warns(Warning):
|
||||
assert not is_image_space_channels_first(channel_mid_space)
|
||||
|
|
|
|||
|
|
@ -341,3 +341,77 @@ def test_vecenv_wrapper_getattr():
|
|||
_ = double_wrapped.var_b
|
||||
with pytest.raises(AttributeError): # should raise as does not exist
|
||||
_ = double_wrapped.nonexistent_attribute
|
||||
|
||||
|
||||
def test_framestack_vecenv():
|
||||
"""Test that framestack environment stacks on desired axis"""
|
||||
|
||||
image_space_shape = [12, 8, 3]
|
||||
zero_acts = np.zeros([N_ENVS] + image_space_shape)
|
||||
|
||||
transposed_image_space_shape = image_space_shape[::-1]
|
||||
transposed_zero_acts = np.zeros([N_ENVS] + transposed_image_space_shape)
|
||||
|
||||
def make_image_env():
|
||||
return CustomGymEnv(
|
||||
gym.spaces.Box(
|
||||
low=np.zeros(image_space_shape),
|
||||
high=np.ones(image_space_shape) * 255,
|
||||
dtype=np.uint8,
|
||||
)
|
||||
)
|
||||
|
||||
def make_transposed_image_env():
|
||||
return CustomGymEnv(
|
||||
gym.spaces.Box(
|
||||
low=np.zeros(transposed_image_space_shape),
|
||||
high=np.ones(transposed_image_space_shape) * 255,
|
||||
dtype=np.uint8,
|
||||
)
|
||||
)
|
||||
|
||||
def make_non_image_env():
|
||||
return CustomGymEnv(gym.spaces.Box(low=np.zeros((2,)), high=np.ones((2,))))
|
||||
|
||||
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2)
|
||||
obs, _, _, _ = vec_env.step(zero_acts)
|
||||
vec_env.close()
|
||||
|
||||
# Should be stacked on the last dimension
|
||||
assert obs.shape[-1] == (image_space_shape[-1] * 2)
|
||||
|
||||
# Try automatic stacking on first dimension now
|
||||
vec_env = DummyVecEnv([make_transposed_image_env for _ in range(N_ENVS)])
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2)
|
||||
obs, _, _, _ = vec_env.step(transposed_zero_acts)
|
||||
vec_env.close()
|
||||
|
||||
# Should be stacked on the first dimension (note the transposing in make_transposed_image_env)
|
||||
assert obs.shape[1] == (image_space_shape[-1] * 2)
|
||||
|
||||
# Try forcing dimensions
|
||||
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="last")
|
||||
obs, _, _, _ = vec_env.step(zero_acts)
|
||||
vec_env.close()
|
||||
|
||||
# Should be stacked on the last dimension
|
||||
assert obs.shape[-1] == (image_space_shape[-1] * 2)
|
||||
|
||||
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="first")
|
||||
obs, _, _, _ = vec_env.step(zero_acts)
|
||||
vec_env.close()
|
||||
|
||||
# Should be stacked on the first dimension
|
||||
assert obs.shape[1] == (image_space_shape[0] * 2)
|
||||
|
||||
# Test invalid channels_order
|
||||
vec_env = DummyVecEnv([make_image_env for _ in range(N_ENVS)])
|
||||
with pytest.raises(AssertionError):
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2, channels_order="not_valid")
|
||||
|
||||
# Test that it works with non-image envs when no channels_order is given
|
||||
vec_env = DummyVecEnv([make_non_image_env for _ in range(N_ENVS)])
|
||||
vec_env = VecFrameStack(vec_env, n_stack=2)
|
||||
|
|
|
|||
Loading…
Reference in a new issue