stable-baselines3/tests/test_envs.py
Quentin Gallouédec e3b24829a5
Drop gym.GoalEnv and other minor changes initially from #780 (#1184)
* Various changes from #780

* Fix env_checker for goal_env detection
2022-11-28 18:22:31 +01:00

278 lines
8.1 KiB
Python

import types
import warnings
import gym
import numpy as np
import pytest
from gym import spaces
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import (
BitFlippingEnv,
FakeImageEnv,
IdentityEnv,
IdentityEnvBox,
IdentityEnvMultiBinary,
IdentityEnvMultiDiscrete,
SimpleMultiObsEnv,
)
# Custom environments that ``test_custom_envs`` instantiates and feeds to
# ``check_env``; each one must pass the checker without any warning.
ENV_CLASSES = [
    BitFlippingEnv,
    IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete,
    FakeImageEnv,
    SimpleMultiObsEnv,
]
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_env(env_id):
    """
    Check that environments registered in Gym pass the env checker.

    :param env_id: (str) id of the registered Gym environment to check
    """
    env = gym.make(env_id)
    try:
        with warnings.catch_warnings(record=True) as record:
            check_env(env)
    finally:
        # Release the environment even if the checker raised
        env.close()

    # Shown on failure so the unexpected/missing warning is visible
    messages = [str(warning.message) for warning in record]
    # Pendulum-v1 will produce a warning because the action space is
    # in [-2, 2] and not [-1, 1]
    if env_id == "Pendulum-v1":
        assert len(record) == 1, messages
    else:
        # The other environments must pass without warning
        assert len(record) == 0, messages
@pytest.mark.parametrize("env_class", ENV_CLASSES)
def test_custom_envs(env_class):
    """
    Each custom environment must pass ``check_env`` silently.

    :param env_class: (type) environment class taken from ``ENV_CLASSES``
    """
    # Build the env outside the recording block: only the checker is audited
    env = env_class()
    with warnings.catch_warnings(record=True) as caught:
        check_env(env)
    # Custom envs are expected to be warning-free
    assert not caught
@pytest.mark.parametrize(
    "kwargs",
    [
        {"continuous": True},
        {"discrete_obs_space": True},
        {"image_obs_space": True, "channel_first": True},
        {"image_obs_space": True, "channel_first": False},
    ],
)
def test_bit_flipping(kwargs):
    """
    Extra ``check_env`` coverage for ``BitFlippingEnv`` built with
    non-default constructor options (``continuous``, ``discrete_obs_space``,
    ``image_obs_space`` / ``channel_first``).

    :param kwargs: (dict) keyword arguments forwarded to ``BitFlippingEnv``
    """
    env = BitFlippingEnv(**kwargs)
    with warnings.catch_warnings(record=True) as caught:
        check_env(env)
    # Every option combination must still be warning-free
    assert not caught
def test_high_dimension_action_space():
    """
    ``check_env`` must accept a continuous (``Box``) action space that has
    many dimensions (20 here) rather than a single action.
    """
    env = FakeImageEnv()
    # Swap in a 20-dimensional continuous action space
    env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)
    # Stub step() to avoid an error with the patched action space: it always
    # returns a valid (obs, reward, done, info) transition
    env.step = lambda _action: (env.observation_space.sample(), 0.0, False, {})
    check_env(env)
@pytest.mark.parametrize(
    "new_obs_space",
    [
        # Small image
        spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        # Range not in [0, 255]
        spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
        # Wrong dtype
        spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
        # Not an image, it should be a 1D vector
        spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
        # Tuple space is not supported by SB3
        spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]),
        # Nested dict space is not supported by SB3
        spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}),
        # Small image inside a dict
        spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
    ],
)
def test_non_default_spaces(new_obs_space):
    """
    Observation spaces listed above must make ``check_env`` emit a ``UserWarning``.

    :param new_obs_space: (gym.spaces.Space) observation space patched into the env
    """
    env = FakeImageEnv()
    env.observation_space = new_obs_space

    # reset() and step() must both emit observations from the patched space,
    # otherwise the checker would complain for an unrelated reason
    env.reset = new_obs_space.sample
    env.step = lambda _action: (new_obs_space.sample(), 0.0, False, {})

    with pytest.warns(UserWarning):
        check_env(env)
@pytest.mark.parametrize(
    "new_action_space",
    [
        # Not symmetric
        spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32),
        # Wrong dtype
        spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float64),
        # Too big range
        spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32),
        # Too small range
        spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32),
        # Inverted boundaries
        spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32),
        # Same boundaries
        spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
        # Unbounded action space
        spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
        # Almost good, except for one dim
        spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
    ],
)
def test_non_default_action_spaces(new_action_space):
    """
    Non-standard continuous action spaces must be reported by ``check_env``:
    unbounded spaces raise an ``AssertionError``, the others only warn.

    :param new_action_space: (gym.spaces.Box) action space patched into the env
    """
    env = FakeImageEnv(discrete=False)
    # Default (unpatched) env: must pass the check without any warning
    with warnings.catch_warnings(record=True) as record:
        check_env(env)
    assert len(record) == 0

    # Change the action space
    env.action_space = new_action_space

    # A space is unbounded as soon as *either* bound is infinite; looking at
    # ``low`` alone would send e.g. ``Box(low=-1, high=np.inf)`` down the
    # warning-only branch even though the checker rejects it.
    is_bounded = np.all(np.isfinite(env.action_space.low)) and np.all(np.isfinite(env.action_space.high))
    if not is_bounded:
        # Unbounded action space throws an error
        with pytest.raises(AssertionError), pytest.warns(UserWarning):
            check_env(env)
    else:
        # The rest only warns
        with pytest.warns(UserWarning):
            check_env(env)
def check_reset_assert_error(env, new_reset_return):
    """
    Patch ``env.reset`` to return ``new_reset_return`` and make sure
    ``check_env`` rejects the env with an ``AssertionError``.

    :param env: (gym.Env) environment to patch (modified in place, the stub stays installed)
    :param new_reset_return: (Any) value returned by the patched ``reset()``
    """
    # Install a no-argument stub that yields the invalid value
    env.reset = lambda: new_reset_return
    with pytest.raises(AssertionError):
        check_env(env)
def test_common_failures_reset():
    """
    Test that common failure cases of the ``reset()`` method are caught
    """
    env = IdentityEnvBox()
    # Return an observation that does not match the observation_space
    check_reset_assert_error(env, np.ones((3,)))
    # The observation is not a numpy array
    check_reset_assert_error(env, 1)
    # Return not only the observation
    check_reset_assert_error(env, (env.observation_space.sample(), False))

    env = SimpleMultiObsEnv()
    # Observation keys and observation space keys must match
    wrong_obs = env.observation_space.sample()
    wrong_obs.pop("img")
    check_reset_assert_error(env, wrong_obs)
    wrong_obs = {**env.observation_space.sample(), "extra_key": None}
    check_reset_assert_error(env, wrong_obs)

    # ``env.reset`` is still the stub installed by the helper above (it returns
    # the dict with "extra_key"), so it must not be used to obtain a reference
    # observation: sample a valid one from the observation space instead.
    obs = env.observation_space.sample()

    def wrong_reset(self):
        # Both keys are present, but "vec" holds the image data
        return {"img": obs["img"], "vec": obs["img"]}

    env.reset = types.MethodType(wrong_reset, env)
    with pytest.raises(AssertionError) as excinfo:
        check_env(env)

    # Check that the key is explicitly mentioned
    assert "vec" in str(excinfo.value)
def check_step_assert_error(env, new_step_return=()):
    """
    Patch ``env.step`` to return ``new_step_return`` for any action and make
    sure ``check_env`` rejects the env with an ``AssertionError``.

    :param env: (gym.Env) environment to patch (modified in place, the stub stays installed)
    :param new_step_return: (tuple) value returned by the patched ``step()``
    """
    # The action is ignored: the stub always yields the invalid transition
    env.step = lambda _action: new_step_return
    with pytest.raises(AssertionError):
        check_env(env)
def test_common_failures_step():
    """
    Malformed ``step()`` return values must be rejected by ``check_env``.
    """
    box_env = IdentityEnvBox()
    # Each factory builds one invalid transition; they are evaluated lazily,
    # in order, right before the corresponding check
    bad_step_returns = [
        # Wrong shape for the observation
        lambda: (np.ones((4,)), 1.0, False, {}),
        # Obs is not a numpy array
        lambda: (1, 1.0, False, {}),
        # Reward is an array instead of a scalar
        lambda: (box_env.observation_space.sample(), np.ones(1), False, {}),
        # Info dict is missing from the returned tuple
        lambda: (box_env.observation_space.sample(), 0.0, False),
        # Done is not a boolean
        lambda: (box_env.observation_space.sample(), 0.0, 3.0, {}),
        lambda: (box_env.observation_space.sample(), 0.0, 1, {}),
    ]
    for make_step_return in bad_step_returns:
        check_step_assert_error(box_env, make_step_return())

    multi_env = SimpleMultiObsEnv()
    # Observation keys and observation space keys must match
    missing_key_obs = multi_env.observation_space.sample()
    missing_key_obs.pop("img")
    check_step_assert_error(multi_env, (missing_key_obs, 0.0, False, {}))
    extra_key_obs = {**multi_env.observation_space.sample(), "extra_key": None}
    check_step_assert_error(multi_env, (extra_key_obs, 0.0, False, {}))

    # Only step() has been patched so far, so reset() is the genuine one
    reference_obs = multi_env.reset()

    def wrong_step(self, action):
        # Keys match, but "img" is filled with the vector observation
        return {"img": reference_obs["vec"], "vec": reference_obs["vec"]}, 0.0, False, {}

    multi_env.step = types.MethodType(wrong_step, multi_env)
    with pytest.raises(AssertionError) as excinfo:
        check_env(multi_env)
    # The error message must name the offending key
    assert "img" in str(excinfo.value)