import types
import warnings

import gym
import numpy as np
import pytest
from gym import spaces

from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.envs import (
    BitFlippingEnv,
    FakeImageEnv,
    IdentityEnv,
    IdentityEnvBox,
    IdentityEnvMultiBinary,
    IdentityEnvMultiDiscrete,
    SimpleMultiObsEnv,
)

ENV_CLASSES = [
    BitFlippingEnv,
    IdentityEnv,
    IdentityEnvBox,
    IdentityEnvMultiBinary,
    IdentityEnvMultiDiscrete,
    FakeImageEnv,
    SimpleMultiObsEnv,
]


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_env(env_id):
    """
    Check that environments integrated in Gym pass the test.

    :param env_id: (str)
    """
    env = gym.make(env_id)
    with warnings.catch_warnings(record=True) as record:
        check_env(env)

    # Pendulum-v1 will produce a warning because the action space is
    # in [-2, 2] and not [-1, 1]
    if env_id == "Pendulum-v1":
        assert len(record) == 1
    else:
        # The other environments must pass without warning
        assert len(record) == 0


@pytest.mark.parametrize("env_class", ENV_CLASSES)
def test_custom_envs(env_class):
    """
    Check that the custom envs shipped with SB3 pass the checker without warnings.

    :param env_class: (type) env class from ``ENV_CLASSES``, built with default arguments
    """
    env = env_class()
    with warnings.catch_warnings(record=True) as record:
        check_env(env)
    # No warnings for custom envs
    assert len(record) == 0


@pytest.mark.parametrize(
    "kwargs",
    [
        dict(continuous=True),
        dict(discrete_obs_space=True),
        dict(image_obs_space=True, channel_first=True),
        dict(image_obs_space=True, channel_first=False),
    ],
)
def test_bit_flipping(kwargs):
    """
    Additional tests for BitFlippingEnv: each configuration must pass without warnings.

    :param kwargs: (dict) keyword arguments forwarded to ``BitFlippingEnv``
    """
    env = BitFlippingEnv(**kwargs)
    with warnings.catch_warnings(record=True) as record:
        check_env(env)

    # No warnings for custom envs
    assert len(record) == 0


def test_high_dimension_action_space():
    """
    Test for continuous action space
    with more than one action.
    """
    env = FakeImageEnv()
    # Patch the action space
    env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)

    # Patch to avoid error
    def patched_step(_action):
        return env.observation_space.sample(), 0.0, False, {}

    env.step = patched_step
    check_env(env)


@pytest.mark.parametrize(
    "new_obs_space",
    [
        # Small image
        spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        # Range not in [0, 255]
        spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
        # Wrong dtype
        spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
        # Not an image, it should be a 1D vector
        spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
        # Tuple space is not supported by SB3
        spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]),
        # Nested dict space is not supported by SB3
        spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}),
        # Small image inside a dict
        spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
    ],
)
def test_non_default_spaces(new_obs_space):
    """
    Check that unusual or unsupported observation spaces trigger a ``UserWarning``.

    :param new_obs_space: (gym.spaces.Space) observation space patched onto the env
    """
    env = FakeImageEnv()
    env.observation_space = new_obs_space
    # Patch methods to avoid errors
    env.reset = new_obs_space.sample

    def patched_step(_action):
        return new_obs_space.sample(), 0.0, False, {}

    env.step = patched_step
    with pytest.warns(UserWarning):
        check_env(env)


@pytest.mark.parametrize(
    "new_action_space",
    [
        # Not symmetric
        spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32),
        # Wrong dtype
        spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float64),
        # Too big range
        spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32),
        # Too small range
        spaces.Box(low=-0.1, high=0.1, shape=(2,), dtype=np.float32),
        # Inverted boundaries
        spaces.Box(low=1, high=-1, shape=(2,), dtype=np.float32),
        # Same boundaries
        spaces.Box(low=1, high=1, shape=(2,), dtype=np.float32),
        # Unbounded action space
        spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
        # Almost good, except for one dim
        spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
    ],
)
def test_non_default_action_spaces(new_action_space):
    """
    Check that non-default continuous action spaces trigger a ``UserWarning``
    (and an ``AssertionError`` when the space is unbounded).

    :param new_action_space: (gym.spaces.Box) action space patched onto the env
    """
    env = FakeImageEnv(discrete=False)
    # Default, should pass the test
    with warnings.catch_warnings(record=True) as record:
        check_env(env)

    # No warnings for custom envs
    assert len(record) == 0

    # Change the action space
    env.action_space = new_action_space

    # Unbounded action space throws an error, the rest only warning.
    # A space is unbounded if EITHER bound is infinite, so both are checked
    # (checking only ``low`` would misclassify e.g. ``high=np.inf``).
    low, high = env.action_space.low, env.action_space.high
    is_bounded = np.all(np.isfinite(low)) and np.all(np.isfinite(high))
    if not is_bounded:
        with pytest.raises(AssertionError), pytest.warns(UserWarning):
            check_env(env)
    else:
        with pytest.warns(UserWarning):
            check_env(env)


def check_reset_assert_error(env, new_reset_return):
    """
    Helper to check that the error is caught.

    Replaces ``env.reset`` with a function returning ``new_reset_return`` and
    asserts that ``check_env`` raises an ``AssertionError``. Note that the
    patched ``reset`` stays on ``env`` after this helper returns.

    :param env: (gym.Env)
    :param new_reset_return: (Any)
    """

    def wrong_reset():
        return new_reset_return

    # Patch the reset method with a wrong one
    env.reset = wrong_reset
    with pytest.raises(AssertionError):
        check_env(env)


def test_common_failures_reset():
    """
    Test that common failure cases of the `reset` method are caught
    """
    env = IdentityEnvBox()
    # Return an observation that does not match the observation_space
    check_reset_assert_error(env, np.ones((3,)))
    # The observation is not a numpy array
    check_reset_assert_error(env, 1)

    # Return not only the observation
    check_reset_assert_error(env, (env.observation_space.sample(), False))

    env = SimpleMultiObsEnv()

    # Observation keys and observation space keys must match
    wrong_obs = env.observation_space.sample()
    wrong_obs.pop("img")
    check_reset_assert_error(env, wrong_obs)
    wrong_obs = {**env.observation_space.sample(), "extra_key": None}
    check_reset_assert_error(env, wrong_obs)

    # ``env.reset`` is still the wrong reset patched in by the helper above,
    # so draw a valid observation directly from the observation space instead.
    obs = env.observation_space.sample()

    def wrong_reset(self):
        return {"img": obs["img"], "vec": obs["img"]}

    env.reset = types.MethodType(wrong_reset, env)
    with pytest.raises(AssertionError) as excinfo:
        check_env(env)

    # Check that the key is explicitly mentioned
    assert "vec" in str(excinfo.value)


def check_step_assert_error(env, new_step_return=()):
    """
    Helper to check that the error is caught.

    Replaces ``env.step`` with a function returning ``new_step_return`` and
    asserts that ``check_env`` raises an ``AssertionError``.

    :param env: (gym.Env)
    :param new_step_return: (tuple)
    """

    def wrong_step(_action):
        return new_step_return

    # Patch the step method with a wrong one
    env.step = wrong_step
    with pytest.raises(AssertionError):
        check_env(env)


def test_common_failures_step():
    """
    Test that common failure cases of the `step` method are caught
    """
    env = IdentityEnvBox()

    # Wrong shape for the observation
    check_step_assert_error(env, (np.ones((4,)), 1.0, False, {}))
    # Obs is not a numpy array
    check_step_assert_error(env, (1, 1.0, False, {}))

    # Return a wrong reward
    check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {}))

    # Info dict is not returned
    check_step_assert_error(env, (env.observation_space.sample(), 0.0, False))

    # Done is not a boolean
    check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {}))
    check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {}))

    env = SimpleMultiObsEnv()

    # Observation keys and observation space keys must match
    wrong_obs = env.observation_space.sample()
    wrong_obs.pop("img")
    check_step_assert_error(env, (wrong_obs, 0.0, False, {}))
    wrong_obs = {**env.observation_space.sample(), "extra_key": None}
    check_step_assert_error(env, (wrong_obs, 0.0, False, {}))

    obs = env.reset()

    def wrong_step(self, action):
        return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, {}

    env.step = types.MethodType(wrong_step, env)
    with pytest.raises(AssertionError) as excinfo:
        check_env(env)

    # Check that the key is explicitly mentioned
    assert "img" in str(excinfo.value)