mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-01 03:45:11 +00:00
Fix support of image like normalized inputs (#1214)
* Fix support of image like normalized inputs * Improve docstring and warning message. * Don't check if obs is image when normalize_images is False (lil opt) * Comment fix * Fix normalize_images not passed to parent * Check for subclasses too * Remove useless multiline * Update version and add comment * Fix some typos Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
parent
ca944fed2d
commit
8452106734
20 changed files with 163 additions and 37 deletions
|
|
@ -105,7 +105,7 @@ train/
|
|||
- ``loss``: Current total loss value
|
||||
- ``n_updates``: Number of gradient updates applied so far
|
||||
- ``policy_gradient_loss``: Current value of the policy gradient loss (its value does not have much meaning)
|
||||
- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carle estimate (or TD(lambda) estimate)
|
||||
- ``value_loss``: Current value for the value function loss for on-policy algorithms, usually error between value function output and Monte-Carlo estimate (or TD(lambda) estimate)
|
||||
- ``std``: Current standard deviation of the noise when using generalized State-Dependent Exploration (gSDE)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,13 @@ That is to say, your environment must implement the following methods (and inher
|
|||
|
||||
|
||||
.. note::
|
||||
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255]
|
||||
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
|
||||
channel-first or channel-last.
|
||||
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255].
|
||||
By default, the observation is normalized by SB3 pre-processing (dividing by 255 to have values in [0, 1]) when using CNN policies.
|
||||
Images can be either channel-first or channel-last.
|
||||
|
||||
If you want to use ``CnnPolicy`` or ``MultiInputPolicy`` with image-like observation (3D tensor) that are already normalized, you must pass ``normalize_images=False``
|
||||
to the policy (using ``policy_kwargs`` parameter, ``policy_kwargs=dict(normalize_images=False)``)
|
||||
and make sure your image is in the **channel-first** format.
|
||||
|
||||
|
||||
.. note::
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
|
|||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
|
||||
super(CustomCNN, self).__init__(observation_space, features_dim)
|
||||
super().__init__(observation_space, features_dim)
|
||||
# We assume CxHxW images (channels first)
|
||||
# Re-ordering will be done by pre-preprocessing or wrapper
|
||||
n_input_channels = observation_space.shape[0]
|
||||
|
|
@ -201,7 +201,7 @@ downsampling and "vector" with a single linear layer.
|
|||
# We do not know features-dim here before going over all the items,
|
||||
# so put something dummy for now. PyTorch requires calling
|
||||
# nn.Module.__init__ before adding modules
|
||||
super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)
|
||||
super().__init__(observation_space, features_dim=1)
|
||||
|
||||
extractors = {}
|
||||
|
||||
|
|
@ -374,7 +374,7 @@ If your task requires even more granular control over the policy/value architect
|
|||
**kwargs,
|
||||
):
|
||||
|
||||
super(CustomActorCriticPolicy, self).__init__(
|
||||
super().__init__(
|
||||
observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ You can use environments with dictionary observation spaces. This is useful in t
|
|||
concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles).
|
||||
Stable Baselines3 provides ``SimpleMultiObsEnv`` as an example of this kind of of setting.
|
||||
The environment is a simple grid world but the observations for each cell come in the form of dictionaries.
|
||||
These dictionaries are randomly initilaized on the creation of the environment and contain a vector observation and an image observation.
|
||||
These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/iss
|
|||
|
||||
The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
|
||||
|
||||
For PPO, assuming a shared feature extactor.
|
||||
For PPO, assuming a shared feature extractor.
|
||||
|
||||
.. warning::
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.7.0a6 (WIP)
|
||||
Release 1.7.0a7 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -13,12 +13,16 @@ Breaking Changes:
|
|||
please use an ``EvalCallback`` instead
|
||||
- Removed deprecated ``sde_net_arch`` parameter
|
||||
- Removed ``ret`` attributes in ``VecNormalize``, please use ``returns`` instead
|
||||
- ``VecNormalize`` now updates the observation space when normalizing images
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Introduced mypy type checking
|
||||
- Added ``with_bias`` argument to ``create_mlp``
|
||||
- Added support for multidimensional ``spaces.MultiBinary`` observations
|
||||
- Features extractors now properly support unnormalized image-like observations (3D tensor)
|
||||
when passing ``normalize_images=False``
|
||||
- Added ``normalized_image`` parameter to ``NatureCNN`` and ``CombinedExtractor``
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
|
|
@ -31,6 +35,8 @@ Bug Fixes:
|
|||
- Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde)
|
||||
- Fix type annotation of ``model`` in ``evaluate_policy``
|
||||
- Fixed ``Self`` return type using ``TypeVar``
|
||||
- Fixed the env checker, the key was not passed when checking images from Dict observation space
|
||||
- Fixed ``normalize_images`` which was not passed to parent class in some cases
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -58,6 +64,7 @@ Documentation:
|
|||
- Changed ``env`` to ``vec_env`` when environment is vectorized
|
||||
- Updated custom policy docs to better explain the ``mlp_extractor``'s dimensions (@AlexPasqua)
|
||||
- Update custom policy documentation (@athatheo)
|
||||
- Clarify doc when using image-like input
|
||||
|
||||
Release 1.6.2 (2022-10-10)
|
||||
--------------------------
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ Driving policies can be trained in different scenarios, and several notebooks us
|
|||
tactile-gym
|
||||
-------------------
|
||||
|
||||
Suite of RL environments focussed on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs.
|
||||
Suite of RL environments focused on using a simulated tactile sensor as the primary source of observations. Sim-to-Real results across 4 out of 5 proposed envs.
|
||||
|
||||
| Author: Alex Church
|
||||
| GitHub: https://github.com/ac-93/tactile_gym
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
|||
"""
|
||||
Check that the input will be compatible with Stable-Baselines
|
||||
when the observation is apparently an image.
|
||||
|
||||
:param observation_space: Observation space
|
||||
:key: When the observation space comes from a Dict space, we pass the
|
||||
corresponding key to have more precise warning messages. Defaults to "".
|
||||
"""
|
||||
if observation_space.dtype != np.uint8:
|
||||
warnings.warn(
|
||||
f"It seems that your observation {key} is an image but the `dtype` "
|
||||
"of your observation_space is not `np.uint8`. "
|
||||
f"It seems that your observation {key} is an image but its `dtype` "
|
||||
f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. "
|
||||
"If your observation is not an image, we recommend you to flatten the observation "
|
||||
"to have only a 1D vector"
|
||||
)
|
||||
|
|
@ -180,7 +184,7 @@ def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
|
|||
# If image, check the low and high values, the type and the number of channels
|
||||
# and the shape (minimal value)
|
||||
if len(observation_space.shape) == 3:
|
||||
_check_image_input(observation_space)
|
||||
_check_image_input(observation_space, key)
|
||||
|
||||
if len(observation_space.shape) not in [1, 3]:
|
||||
warnings.warn(
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class Image:
|
|||
|
||||
class HParam:
|
||||
"""
|
||||
Hyperparameter data class storing hyperparameters and metrics in dictionnaries
|
||||
Hyperparameter data class storing hyperparameters and metrics in dictionaries
|
||||
|
||||
:param hparam_dict: key-value pairs of hyperparameters to log
|
||||
:param metric_dict: key-value pairs of metrics to log
|
||||
|
|
|
|||
|
|
@ -87,6 +87,9 @@ class BaseModel(nn.Module):
|
|||
|
||||
self.features_extractor_class = features_extractor_class
|
||||
self.features_extractor_kwargs = features_extractor_kwargs
|
||||
# Automatically deactivate dtype and bounds checks
|
||||
if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
|
||||
self.features_extractor_kwargs.update(dict(normalized_image=True))
|
||||
|
||||
def _update_features_extractor(
|
||||
self,
|
||||
|
|
@ -430,6 +433,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
squash_output=squash_output,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
|
|
@ -446,7 +450,6 @@ class ActorCriticPolicy(BasePolicy):
|
|||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
self.normalize_images = normalize_images
|
||||
self.log_std_init = log_std_init
|
||||
dist_kwargs = None
|
||||
# Keyword arguments for gSDE distribution
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
|
|||
def is_image_space(
|
||||
observation_space: spaces.Space,
|
||||
check_channels: bool = False,
|
||||
normalized_image: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a observation space has the shape, limits and dtype
|
||||
|
|
@ -38,15 +39,21 @@ def is_image_space(
|
|||
:param observation_space:
|
||||
:param check_channels: Whether to do or not the check for the number of channels.
|
||||
e.g., with frame-stacking, the observation space may have more channels than expected.
|
||||
:param normalized_image: Whether to assume that the image is already normalized
|
||||
or not (this disables dtype and bounds checks): when True, it only checks that
|
||||
the space is a Box and has 3 dimensions.
|
||||
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
|
||||
:return:
|
||||
"""
|
||||
check_dtype = check_bounds = not normalized_image
|
||||
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
|
||||
# Check the type
|
||||
if observation_space.dtype != np.uint8:
|
||||
if check_dtype and observation_space.dtype != np.uint8:
|
||||
return False
|
||||
|
||||
# Check the value range
|
||||
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
|
||||
incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255)
|
||||
if check_bounds and incorrect_bounds:
|
||||
return False
|
||||
|
||||
# Skip channels check
|
||||
|
|
@ -57,7 +64,7 @@ def is_image_space(
|
|||
n_channels = observation_space.shape[0]
|
||||
else:
|
||||
n_channels = observation_space.shape[-1]
|
||||
# RGB, RGBD, GrayScale
|
||||
# GrayScale, RGB, RGBD
|
||||
return n_channels in [1, 3, 4]
|
||||
return False
|
||||
|
||||
|
|
@ -99,7 +106,7 @@ def preprocess_obs(
|
|||
:return:
|
||||
"""
|
||||
if isinstance(observation_space, spaces.Box):
|
||||
if is_image_space(observation_space) and normalize_images:
|
||||
if normalize_images and is_image_space(observation_space):
|
||||
return obs.float() / 255.0
|
||||
return obs.float()
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class BaseFeaturesExtractor(nn.Module):
|
|||
:param features_dim: Number of features extracted.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.Space, features_dim: int = 0):
|
||||
def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
|
||||
super().__init__()
|
||||
assert features_dim > 0
|
||||
self._observation_space = observation_space
|
||||
|
|
@ -37,7 +37,7 @@ class FlattenExtractor(BaseFeaturesExtractor):
|
|||
:param observation_space:
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.Space):
|
||||
def __init__(self, observation_space: gym.Space) -> None:
|
||||
super().__init__(observation_space, get_flattened_obs_dim(observation_space))
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
|
|
@ -55,19 +55,31 @@ class NatureCNN(BaseFeaturesExtractor):
|
|||
:param observation_space:
|
||||
:param features_dim: Number of features extracted.
|
||||
This corresponds to the number of unit for the last layer.
|
||||
:param normalized_image: Whether to assume that the image is already normalized
|
||||
or not (this disables dtype and bounds checks): when True, it only checks that
|
||||
the space is a Box and has 3 dimensions.
|
||||
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Box,
|
||||
features_dim: int = 512,
|
||||
normalized_image: bool = False,
|
||||
) -> None:
|
||||
super().__init__(observation_space, features_dim)
|
||||
# We assume CxHxW images (channels first)
|
||||
# Re-ordering will be done by pre-preprocessing or wrapper
|
||||
assert is_image_space(observation_space, check_channels=False), (
|
||||
assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
|
||||
"You should use NatureCNN "
|
||||
f"only with images not with {observation_space}\n"
|
||||
"(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
|
||||
"If you are using a custom environment,\n"
|
||||
"please check it using our env checker:\n"
|
||||
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
|
||||
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
|
||||
"If you are using `VecNormalize` or already normalized channel-first images "
|
||||
"you should pass `normalize_images=False`: \n"
|
||||
"https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
|
||||
)
|
||||
n_input_channels = observation_space.shape[0]
|
||||
self.cnn = nn.Sequential(
|
||||
|
|
@ -167,7 +179,7 @@ class MlpExtractor(nn.Module):
|
|||
net_arch: List[Union[int, Dict[str, List[int]]]],
|
||||
activation_fn: Type[nn.Module],
|
||||
device: Union[th.device, str] = "auto",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
device = get_device(device)
|
||||
shared_net: List[nn.Module] = []
|
||||
|
|
@ -247,9 +259,18 @@ class CombinedExtractor(BaseFeaturesExtractor):
|
|||
:param observation_space:
|
||||
:param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to
|
||||
256 to avoid exploding network sizes.
|
||||
:param normalized_image: Whether to assume that the image is already normalized
|
||||
or not (this disables dtype and bounds checks): when True, it only checks that
|
||||
the space is a Box and has 3 dimensions.
|
||||
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.spaces.Dict,
|
||||
cnn_output_dim: int = 256,
|
||||
normalized_image: bool = False,
|
||||
) -> None:
|
||||
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
|
||||
super().__init__(observation_space, features_dim=1)
|
||||
|
||||
|
|
@ -257,8 +278,8 @@ class CombinedExtractor(BaseFeaturesExtractor):
|
|||
|
||||
total_concat_size = 0
|
||||
for key, subspace in observation_space.spaces.items():
|
||||
if is_image_space(subspace):
|
||||
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim)
|
||||
if is_image_space(subspace, normalized_image=normalized_image):
|
||||
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image)
|
||||
total_concat_size += cnn_output_dim
|
||||
else:
|
||||
# The observation key is a vector, flatten it if needed
|
||||
|
|
|
|||
|
|
@ -187,7 +187,7 @@ class StackedDictObservations(StackedObservations):
|
|||
|
||||
def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict:
|
||||
"""
|
||||
Returns the stacked verson of a Dict observation space
|
||||
Returns the stacked version of a Dict observation space
|
||||
|
||||
:param observation_space: Dict observation space to stack
|
||||
:return: stacked observation space
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import gym
|
|||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import utils
|
||||
from stable_baselines3.common.preprocessing import is_image_space
|
||||
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
|
||||
|
||||
|
|
@ -50,9 +51,35 @@ class VecNormalize(VecEnvWrapper):
|
|||
if isinstance(self.observation_space, gym.spaces.Dict):
|
||||
self.obs_spaces = self.observation_space.spaces
|
||||
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
|
||||
# Update observation space when using image
|
||||
# See explanation below and GH #1214
|
||||
for key in self.obs_rms.keys():
|
||||
if is_image_space(self.obs_spaces[key]):
|
||||
self.observation_space.spaces[key] = gym.spaces.Box(
|
||||
low=-clip_obs,
|
||||
high=clip_obs,
|
||||
shape=self.obs_spaces[key].shape,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
else:
|
||||
self.obs_spaces = None
|
||||
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
|
||||
# Update observation space when using image
|
||||
# See GH #1214
|
||||
# This is to raise proper error when
|
||||
# VecNormalize is used with an image-like input and
|
||||
# normalize_images=True.
|
||||
# For correctness, we should also update the bounds
|
||||
# in other cases but this will cause backward-incompatible change
|
||||
# and break already saved policies.
|
||||
if is_image_space(self.observation_space):
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=-clip_obs,
|
||||
high=clip_obs,
|
||||
shape=self.observation_space.shape,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
self.ret_rms = RunningMeanStd(shape=())
|
||||
self.clip_obs = clip_obs
|
||||
|
|
|
|||
|
|
@ -51,7 +51,6 @@ class QNetwork(BasePolicy):
|
|||
self.activation_fn = activation_fn
|
||||
self.features_extractor = features_extractor
|
||||
self.features_dim = features_dim
|
||||
self.normalize_images = normalize_images
|
||||
action_dim = self.action_space.n # number of actions
|
||||
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
|
||||
self.q_net = nn.Sequential(*q_net)
|
||||
|
|
@ -125,6 +124,7 @@ class DQNPolicy(BasePolicy):
|
|||
features_extractor_kwargs,
|
||||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
@ -135,7 +135,6 @@ class DQNPolicy(BasePolicy):
|
|||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.normalize_images = normalize_images
|
||||
|
||||
self.net_args = {
|
||||
"observation_space": self.observation_space,
|
||||
|
|
|
|||
|
|
@ -232,6 +232,7 @@ class SACPolicy(BasePolicy):
|
|||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
squash_output=True,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
if net_arch is None:
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ class TD3Policy(BasePolicy):
|
|||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
squash_output=True,
|
||||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
# Default network architecture, from the original paper
|
||||
|
|
@ -177,7 +178,7 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
if self.share_features_extractor:
|
||||
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
||||
# Critic target should not share the features extactor with critic
|
||||
# Critic target should not share the features extractor with critic
|
||||
# but it can share it with the actor target as actor and critic are sharing
|
||||
# the same features_extractor too
|
||||
# NOTE: as a result the effective poliak (soft-copy) coefficient for the features extractor
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.7.0a6
|
||||
1.7.0a7
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
|||
from stable_baselines3.common.envs import FakeImageEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage, is_vecenv_wrapped
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage, is_vecenv_wrapped
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
|
||||
|
|
@ -269,6 +269,10 @@ def test_image_space_checks():
|
|||
not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8)
|
||||
assert not is_image_space(not_image_space)
|
||||
|
||||
# Deactivate dtype and bound checking
|
||||
normalized_image = spaces.Box(0, 1, shape=(10, 10, 3), dtype=np.float32)
|
||||
assert is_image_space(normalized_image, normalized_image=True)
|
||||
|
||||
# Not correct space
|
||||
not_image_space = spaces.Discrete(n=10)
|
||||
assert not is_image_space(not_image_space)
|
||||
|
|
@ -297,3 +301,44 @@ def test_image_space_checks():
|
|||
# Should raise a warning
|
||||
with pytest.warns(Warning):
|
||||
assert not is_image_space_channels_first(channel_mid_space)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, SAC, TD3])
|
||||
@pytest.mark.parametrize("normalize_images", [True, False])
|
||||
def test_image_like_input(model_class, normalize_images):
|
||||
"""
|
||||
Check that we can handle image-like input (3D tensor)
|
||||
when normalize_images=False
|
||||
"""
|
||||
# Fake grayscale with frameskip
|
||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||
# to check that the network handle it automatically
|
||||
env = FakeImageEnv(
|
||||
screen_height=36,
|
||||
screen_width=36,
|
||||
n_channels=1,
|
||||
channel_first=True,
|
||||
discrete=model_class not in {SAC, TD3},
|
||||
)
|
||||
vec_env = VecNormalize(DummyVecEnv([lambda: env]))
|
||||
# Reduce the size of the features
|
||||
# deactivate normalization
|
||||
kwargs = dict(
|
||||
policy_kwargs=dict(
|
||||
normalize_images=normalize_images,
|
||||
features_extractor_kwargs=dict(features_dim=32),
|
||||
),
|
||||
seed=1,
|
||||
)
|
||||
|
||||
if model_class in {A2C, PPO}:
|
||||
kwargs.update(dict(n_steps=64))
|
||||
else:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs.update(dict(buffer_size=250))
|
||||
if normalize_images:
|
||||
with pytest.raises(AssertionError):
|
||||
model_class("CnnPolicy", vec_env, **kwargs).learn(128)
|
||||
else:
|
||||
model_class("CnnPolicy", vec_env, **kwargs).learn(128)
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ class DummyDictEnv(gym.Env):
|
|||
else:
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
|
||||
N_CHANNELS = 1
|
||||
HEIGHT = 64
|
||||
WIDTH = 64
|
||||
HEIGHT = 36
|
||||
WIDTH = 36
|
||||
|
||||
if channel_last:
|
||||
obs_shape = (HEIGHT, WIDTH, N_CHANNELS)
|
||||
|
|
@ -323,3 +323,10 @@ def test_dict_nested():
|
|||
|
||||
with pytest.raises(NotImplementedError):
|
||||
env = DummyVecEnv([lambda: DummyDictEnv(nested_dict_obs=True)])
|
||||
|
||||
|
||||
def test_vec_normalize_image():
|
||||
env = VecNormalize(DummyVecEnv([lambda: DummyDictEnv()]), norm_obs_keys=["img"])
|
||||
assert env.observation_space.spaces["img"].dtype == np.float32
|
||||
assert (env.observation_space.spaces["img"].low == -env.clip_obs).all()
|
||||
assert (env.observation_space.spaces["img"].high == env.clip_obs).all()
|
||||
|
|
|
|||
Loading…
Reference in a new issue