From b6aa507a22b304050cdd1f95cb0343bec72a4a8c Mon Sep 17 00:00:00 2001 From: Fiete <41323592+FieteO@users.noreply.github.com> Date: Wed, 29 Mar 2023 15:26:03 +0200 Subject: [PATCH] Make check_env assertions in regards to observation_space more actionable (#1400) * add instructions for running single tests in the README, add assertions for observation_space * update changelog * address linting warnings * correct pytest command in the README * correct review comments, run make commit-checks * truncate lines that are too long * address make lint warning about checking module availability * fix tests * use f-strings for formatting assertion messages * fix type issue * Refactor tests, improve error messages --------- Co-authored-by: Antonin Raffin --- README.md | 22 +++++-- docs/misc/changelog.rst | 5 +- stable_baselines3/common/env_checker.py | 23 ++++++++ stable_baselines3/version.txt | 2 +- tests/test_env_checker.py | 77 +++++++++++++++++++++++++ tests/test_logger.py | 7 +-- 6 files changed, 123 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0e0b38d..f7b9adc 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ pip install stable-baselines3[extra] **Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)). This includes an optional dependencies like Tensorboard, OpenCV or `atari-py` to train on atari games. 
If you do not need those, you can use: -``` +```sh pip install stable-baselines3 ``` @@ -194,20 +194,32 @@ Actions `gym.spaces`: ## Testing the installation -All unit tests in stable baselines3 can be run using `pytest` runner: +### Install dependencies +```sh +pip install -e .[docs,tests,extra] ``` -pip install pytest pytest-cov +### Run tests +All unit tests in stable baselines3 can be run using `pytest` runner: +```sh make pytest ``` +To run a single test file: +```sh +python3 -m pytest -v tests/test_env_checker.py +``` +To run a single test: +```sh +python3 -m pytest -v -k 'test_check_env_dict_action' +``` You can also do a static type check using `pytype` and `mypy`: -``` +```sh pip install pytype mypy make type ``` Codestyle check with `ruff`: -``` +```sh pip install ruff make lint ``` diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 840ab60..e96e26e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.8.0a10 (WIP) +Release 1.8.0a11 (WIP) -------------------------- .. 
warning:: @@ -31,6 +31,7 @@ New Features: - Added support for dict/tuple observations spaces for ``VecCheckNan``, the check is now active in the ``env_checker()`` (@DavyMorgan) - Added multiprocessing support for ``HerReplayBuffer`` - ``HerReplayBuffer`` now supports all datatypes supported by ``ReplayBuffer`` +- Provide more helpful failure messages when validating the ``observation_space`` of custom gym environments using ``check_env`` (@FieteO) `SB3-Contrib`_ @@ -1251,4 +1252,4 @@ And all the contributors: @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 950cea6..b71454b 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -169,6 +169,29 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac elif _is_numpy_array_space(observation_space): assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array" + # Additional checks for numpy arrays, so the error message is clearer (see GH#1399) + if isinstance(obs, np.ndarray): + # check obs dimensions, dtype and bounds + assert observation_space.shape == obs.shape, ( + f"The observation returned by the `{method_name}()` method does not match the shape " + f"of the given observation space. 
Expected: {observation_space.shape}, actual shape: {obs.shape}" + ) + assert observation_space.dtype == obs.dtype, ( + f"The observation returned by the `{method_name}()` method does not match the data type " + f"of the given observation space. Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" + ) + if isinstance(observation_space, spaces.Box): + assert np.all(obs >= observation_space.low), ( + f"The observation returned by the `{method_name}()` method does not match the lower bound " + f"of the given observation space. Expected: obs >= {np.min(observation_space.low)}, " + f"actual min value: {np.min(obs)} at index {np.argmin(obs)}" + ) + assert np.all(obs <= observation_space.high), ( + f"The observation returned by the `{method_name}()` method does not match the upper bound " + f"of the given observation space. Expected: obs <= {np.max(observation_space.high)}, " + f"actual max value: {np.max(obs)} at index {np.argmax(obs)}" + ) + assert observation_space.contains( obs ), f"The observation returned by the `{method_name}()` method does not match the given observation space" diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index cba76ac..01d49ad 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a10 +1.8.0a11 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 3159786..94aeb3c 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -29,3 +29,80 @@ def test_check_env_dict_action(): with pytest.warns(Warning): check_env(env=test_env, warn=True) + + +@pytest.mark.parametrize( + "obs_tuple", + [ + # Above upper bound + ( + spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32), + np.array([1.0, 1.5, 0.5], dtype=np.float32), + r"Expected: obs <= 1\.0, actual max value: 1\.5 at index 1", + ), + # Below lower bound + ( + spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32), + np.array([-1.0, 1.5, 0.5], dtype=np.float32), + r"Expected: 
obs >= 0\.0, actual min value: -1\.0 at index 0", + ), + # Wrong dtype + ( + spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32), + np.array([1.0, 1.5, 0.5], dtype=np.float64), + r"Expected: float32, actual dtype: float64", + ), + # Wrong shape + ( + spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32), + np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32), + r"Expected: \(3,\), actual shape: \(2, 3\)", + ), + # Wrong shape (dict obs) + ( + spaces.Dict({"obs": spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)}), + {"obs": np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32)}, + r"Error while checking key=obs.*Expected: \(3,\), actual shape: \(2, 3\)", + ), + # Wrong shape (multi discrete) + ( + spaces.MultiDiscrete([3, 3]), + np.array([[2, 0]]), + r"Expected: \(2,\), actual shape: \(1, 2\)", + ), + # Wrong shape (multi binary) + ( + spaces.MultiBinary(3), + np.array([[1, 0, 0]]), + r"Expected: \(3,\), actual shape: \(1, 3\)", + ), + ], +) +@pytest.mark.parametrize( + # Check when it happens at reset or during step + "method", + ["reset", "step"], +) +def test_check_env_detailed_error(obs_tuple, method): + """ + Check that the env checker returns a more detailed error + when the observation is not in the obs space. 
+ """ + observation_space, wrong_obs, error_message = obs_tuple + good_obs = observation_space.sample() + + class TestEnv(gym.Env): + action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32) + + def reset(self): + return wrong_obs if method == "reset" else good_obs + + def step(self, action): + obs = wrong_obs if method == "step" else good_obs + return obs, 0.0, True, {} + + TestEnv.observation_space = observation_space + + test_env = TestEnv() + with pytest.raises(AssertionError, match=error_message): + check_env(env=test_env) diff --git a/tests/test_logger.py b/tests/test_logger.py index a1d1052..1bc11e5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,3 +1,4 @@ +import importlib.util import os import sys import time @@ -233,11 +234,7 @@ def test_report_video_to_tensorboard(tmp_path, read_log, capsys): def is_moviepy_installed(): - try: - import moviepy - except ModuleNotFoundError: - return False - return True + return importlib.util.find_spec("moviepy") is not None @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])