stable-baselines3/tests/test_envs.py
Antonin RAFFIN e24147390d
Improve tests and add check for float32 (#686)
* Add additional checks

* Improve tests and error message

* Update changelog

* Bump version

* Update doc

* Add tests for action space

* Improve test
2021-12-09 14:14:33 +02:00

253 lines
7.1 KiB
Python

import types
import warnings

import gym
import numpy as np
import pytest
from gym import spaces

from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import (
    BitFlippingEnv,
    FakeImageEnv,
    IdentityEnv,
    IdentityEnvBox,
    IdentityEnvMultiBinary,
    IdentityEnvMultiDiscrete,
    SimpleMultiObsEnv,
)
# Environment classes imported from ``stable_baselines3.common.envs``.
# ``test_custom_envs`` instantiates each of them without arguments and
# runs the result through ``check_env``.
ENV_CLASSES = [
    BitFlippingEnv,
    IdentityEnv,
    IdentityEnvBox,
    IdentityEnvMultiBinary,
    IdentityEnvMultiDiscrete,
    FakeImageEnv,
    SimpleMultiObsEnv,
]
@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v0"])
def test_env(env_id):
    """
    Check that environments integrated in Gym pass the test.

    :param env_id: (str) Id passed to ``gym.make`` to create the environment
    """
    env = gym.make(env_id)
    # ``pytest.warns(None)`` is deprecated since pytest 7 and removed in pytest 8,
    # so the warnings are recorded with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        # Make sure no warning is hidden by an existing filter
        warnings.simplefilter("always")
        check_env(env)

    # Pendulum-v0 will produce a warning because the action space is
    # in [-2, 2] and not [-1, 1]
    if env_id == "Pendulum-v0":
        assert len(record) == 1
    else:
        # The other environments must pass without warning
        assert len(record) == 0
@pytest.mark.parametrize("env_class", ENV_CLASSES)
def test_custom_envs(env_class):
    """
    Check that an environment built with default arguments
    passes the env checker without emitting any warning.

    :param env_class: Environment class, instantiated without arguments
    """
    env = env_class()
    # ``pytest.warns(None)`` is deprecated since pytest 7 and removed in pytest 8,
    # so the warnings are recorded with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        # Make sure no warning is hidden by an existing filter
        warnings.simplefilter("always")
        check_env(env)

    # No warnings for custom envs
    assert len(record) == 0
@pytest.mark.parametrize(
    "kwargs",
    [
        dict(continuous=True),
        dict(discrete_obs_space=True),
        dict(image_obs_space=True, channel_first=True),
        dict(image_obs_space=True, channel_first=False),
    ],
)
def test_bit_flipping(kwargs):
    """
    Additional tests for BitFlippingEnv: each configuration
    must pass the env checker without emitting any warning.

    :param kwargs: (dict) Keyword arguments forwarded to ``BitFlippingEnv``
    """
    env = BitFlippingEnv(**kwargs)
    # ``pytest.warns(None)`` is deprecated since pytest 7 and removed in pytest 8,
    # so the warnings are recorded with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        # Make sure no warning is hidden by an existing filter
        warnings.simplefilter("always")
        check_env(env)

    # No warnings for custom envs
    assert len(record) == 0
def test_high_dimension_action_space():
    """
    Run the env checker on an environment whose continuous
    action space has more than one dimension.
    """
    image_env = FakeImageEnv()

    def fake_step(_action):
        # Replacement for the original ``step`` (patched to avoid an error):
        # returns a random observation, zero reward, not done, empty info.
        sampled_obs = image_env.observation_space.sample()
        return sampled_obs, 0.0, False, {}

    # Swap in a 20-dimensional Box action space and the matching step function
    image_env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)
    image_env.step = fake_step

    check_env(image_env)
@pytest.mark.parametrize(
    "new_obs_space",
    [
        # Image smaller than usual
        spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        # Image bounds different from [0, 255]
        spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
        # Image with a dtype other than uint8
        spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
        # 2D observation that is not an image (a 1D vector is expected)
        spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
        # Tuple spaces are not supported by SB3
        spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]),
        # Nested Dict spaces are not supported by SB3
        spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}),
        # Small image wrapped in a Dict space
        spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
    ],
)
def test_non_default_spaces(new_obs_space):
    """
    Check that the env checker emits a ``UserWarning``
    for unusual or unsupported observation spaces.

    :param new_obs_space: (gym.spaces.Space) Observation space set on the env
    """

    def sampled_step(_action):
        # Random observation from the new space, zero reward, not done, empty info
        return new_obs_space.sample(), 0.0, False, {}

    env = FakeImageEnv()
    env.observation_space = new_obs_space
    # ``reset`` and ``step`` are replaced so that they return
    # observations sampled from the new space (avoids errors)
    env.reset = new_obs_space.sample
    env.step = sampled_step

    with pytest.warns(UserWarning):
        check_env(env)
@pytest.mark.parametrize(
    "new_action_space",
    [
        # Not symmetric
        spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32),
        # Wrong dtype
        spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float64),
        # Too big range
        spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32),
        # Too small range
        spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32),
        # Inverted boundaries
        spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32),
        # Same boundaries
        spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
        # Almost good, except for one dim
        spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
    ],
)
def test_non_default_action_spaces(new_action_space):
    """
    Check that the env checker emits a ``UserWarning``
    for continuous action spaces that are not recommended.

    :param new_action_space: (gym.spaces.Box) Action space set on the env
        after the default one has been checked
    """
    env = FakeImageEnv(discrete=False)
    # Default, should pass the test.
    # ``pytest.warns(None)`` is deprecated since pytest 7 and removed in pytest 8,
    # so the warnings are recorded with the standard library instead.
    with warnings.catch_warnings(record=True) as record:
        # Make sure no warning is hidden by an existing filter
        warnings.simplefilter("always")
        check_env(env)

    # No warnings for custom envs
    assert len(record) == 0

    # Change the action space
    env.action_space = new_action_space

    with pytest.warns(UserWarning):
        check_env(env)
def check_reset_assert_error(env, new_reset_return):
    """
    Replace ``env.reset`` by a function returning ``new_reset_return``
    and verify that the env checker raises an ``AssertionError``.

    :param env: (gym.Env) Environment whose ``reset`` attribute is overwritten
    :param new_reset_return: (Any) Value returned by the patched ``reset``
    """
    # Overwrite reset with a faulty version that ignores the env state
    env.reset = lambda: new_reset_return

    with pytest.raises(AssertionError):
        check_env(env)
def test_common_failures_reset():
    """
    Test that common failure cases of the `reset` method are caught
    """
    box_env = IdentityEnvBox()
    # Observation that does not match the observation_space
    check_reset_assert_error(box_env, np.ones((3,)))
    # Observation that is not a numpy array
    check_reset_assert_error(box_env, 1)
    # More than the observation is returned
    check_reset_assert_error(box_env, (box_env.observation_space.sample(), False))

    multi_obs_env = SimpleMultiObsEnv()
    first_obs = multi_obs_env.reset()

    def reset_with_wrong_vec(_self):
        # The image observation is returned under the "vec" key as well
        image = first_obs["img"]
        return {"img": image, "vec": image}

    multi_obs_env.reset = types.MethodType(reset_with_wrong_vec, multi_obs_env)

    with pytest.raises(AssertionError) as excinfo:
        check_env(multi_obs_env)

    # The error message must explicitly mention the faulty key
    error_message = str(excinfo.value)
    assert "vec" in error_message
def check_step_assert_error(env, new_step_return=()):
    """
    Replace ``env.step`` by a function returning ``new_step_return``
    and verify that the env checker raises an ``AssertionError``.

    :param env: (gym.Env) Environment whose ``step`` attribute is overwritten
    :param new_step_return: (tuple) Value returned by the patched ``step``
    """
    # Overwrite step with a faulty version that ignores the action
    env.step = lambda _action: new_step_return

    with pytest.raises(AssertionError):
        check_env(env)
def test_common_failures_step():
    """
    Test that common failure cases of the `step` method are caught
    """
    box_env = IdentityEnvBox()

    # Observation with a wrong shape
    check_step_assert_error(box_env, (np.ones((4,)), 1.0, False, {}))
    # Observation that is not a numpy array
    check_step_assert_error(box_env, (1, 1.0, False, {}))
    # Reward of the wrong type
    check_step_assert_error(box_env, (box_env.observation_space.sample(), np.ones(1), False, {}))
    # Missing info dict
    check_step_assert_error(box_env, (box_env.observation_space.sample(), 0.0, False))
    # Done flag that is not a boolean
    for not_a_bool in (3.0, 1):
        check_step_assert_error(box_env, (box_env.observation_space.sample(), 0.0, not_a_bool, {}))

    multi_obs_env = SimpleMultiObsEnv()
    first_obs = multi_obs_env.reset()

    def step_with_wrong_img(_self, _action):
        # The vector observation is returned under the "img" key as well
        vector = first_obs["vec"]
        return {"img": vector, "vec": vector}, 0.0, False, {}

    multi_obs_env.step = types.MethodType(step_with_wrong_img, multi_obs_env)

    with pytest.raises(AssertionError) as excinfo:
        check_env(multi_obs_env)

    # The error message must explicitly mention the faulty key
    error_message = str(excinfo.value)
    assert "img" in error_message