mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-05 00:00:04 +00:00
Make check_env assertions in regards to observation_space more actionable (#1400)
* add instructions for running single tests in the README, add assertions for observation_space * update changelog * address linting warnings * correct pytest command in the README * correct review comments, run make commit-checks * truncate lines that are too long * address make lint warning about checking module availability * fix tests * use f-strings for formatting assertion messages * fix type issue * Refactor tests, improve error messages --------- Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
This commit is contained in:
parent
c5adad82b2
commit
b6aa507a22
6 changed files with 123 additions and 13 deletions
22
README.md
22
README.md
|
|
@ -104,7 +104,7 @@ pip install stable-baselines3[extra]
|
|||
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).
|
||||
|
||||
This includes an optional dependencies like Tensorboard, OpenCV or `atari-py` to train on atari games. If you do not need those, you can use:
|
||||
```
|
||||
```sh
|
||||
pip install stable-baselines3
|
||||
```
|
||||
|
||||
|
|
@ -194,20 +194,32 @@ Actions `gym.spaces`:
|
|||
|
||||
|
||||
## Testing the installation
|
||||
All unit tests in stable baselines3 can be run using `pytest` runner:
|
||||
### Install dependencies
|
||||
```sh
|
||||
pip install -e .[docs,tests,extra]
|
||||
```
|
||||
pip install pytest pytest-cov
|
||||
### Run tests
|
||||
All unit tests in stable baselines3 can be run using `pytest` runner:
|
||||
```sh
|
||||
make pytest
|
||||
```
|
||||
To run a single test file:
|
||||
```sh
|
||||
python3 -m pytest -v tests/test_env_checker.py
|
||||
```
|
||||
To run a single test:
|
||||
```sh
|
||||
python3 -m pytest -v -k 'test_check_env_dict_action'
|
||||
```
|
||||
|
||||
You can also do a static type check using `pytype` and `mypy`:
|
||||
```
|
||||
```sh
|
||||
pip install pytype mypy
|
||||
make type
|
||||
```
|
||||
|
||||
Codestyle check with `ruff`:
|
||||
```
|
||||
```sh
|
||||
pip install ruff
|
||||
make lint
|
||||
```
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.8.0a10 (WIP)
|
||||
Release 1.8.0a11 (WIP)
|
||||
--------------------------
|
||||
|
||||
.. warning::
|
||||
|
|
@ -31,6 +31,7 @@ New Features:
|
|||
- Added support for dict/tuple observations spaces for ``VecCheckNan``, the check is now active in the ``env_checker()`` (@DavyMorgan)
|
||||
- Added multiprocessing support for ``HerReplayBuffer``
|
||||
- ``HerReplayBuffer`` now supports all datatypes supported by ``ReplayBuffer``
|
||||
- Provide more helpful failure messages when validating the ``observation_space`` of custom gym environments using ``check_env``` (@FieteO)
|
||||
|
||||
|
||||
`SB3-Contrib`_
|
||||
|
|
@ -1251,4 +1252,4 @@ And all the contributors:
|
|||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO
|
||||
|
|
|
|||
|
|
@ -169,6 +169,29 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac
|
|||
elif _is_numpy_array_space(observation_space):
|
||||
assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"
|
||||
|
||||
# Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
|
||||
if isinstance(obs, np.ndarray):
|
||||
# check obs dimensions, dtype and bounds
|
||||
assert observation_space.shape == obs.shape, (
|
||||
f"The observation returned by the `{method_name}()` method does not match the shape "
|
||||
f"of the given observation space. Expected: {observation_space.shape}, actual shape: {obs.shape}"
|
||||
)
|
||||
assert observation_space.dtype == obs.dtype, (
|
||||
f"The observation returned by the `{method_name}()` method does not match the data type "
|
||||
f"of the given observation space. Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
|
||||
)
|
||||
if isinstance(observation_space, spaces.Box):
|
||||
assert np.all(obs >= observation_space.low), (
|
||||
f"The observation returned by the `{method_name}()` method does not match the lower bound "
|
||||
f"of the given observation space. Expected: obs >= {np.min(observation_space.low)}, "
|
||||
f"actual min value: {np.min(obs)} at index {np.argmin(obs)}"
|
||||
)
|
||||
assert np.all(obs <= observation_space.high), (
|
||||
f"The observation returned by the `{method_name}()` method does not match the upper bound "
|
||||
f"of the given observation space. Expected: obs <= {np.max(observation_space.high)}, "
|
||||
f"actual max value: {np.max(obs)} at index {np.argmax(obs)}"
|
||||
)
|
||||
|
||||
assert observation_space.contains(
|
||||
obs
|
||||
), f"The observation returned by the `{method_name}()` method does not match the given observation space"
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a10
|
||||
1.8.0a11
|
||||
|
|
|
|||
|
|
@ -29,3 +29,80 @@ def test_check_env_dict_action():
|
|||
|
||||
with pytest.warns(Warning):
|
||||
check_env(env=test_env, warn=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"obs_tuple",
|
||||
[
|
||||
# Above upper bound
|
||||
(
|
||||
spaces.Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32),
|
||||
np.array([1.0, 1.5, 0.5], dtype=np.float32),
|
||||
r"Expected: obs <= 1\.0, actual max value: 1\.5 at index 1",
|
||||
),
|
||||
# Below lower bound
|
||||
(
|
||||
spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32),
|
||||
np.array([-1.0, 1.5, 0.5], dtype=np.float32),
|
||||
r"Expected: obs >= 0\.0, actual min value: -1\.0 at index 0",
|
||||
),
|
||||
# Wrong dtype
|
||||
(
|
||||
spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
|
||||
np.array([1.0, 1.5, 0.5], dtype=np.float64),
|
||||
r"Expected: float32, actual dtype: float64",
|
||||
),
|
||||
# Wrong shape
|
||||
(
|
||||
spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32),
|
||||
np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32),
|
||||
r"Expected: \(3,\), actual shape: \(2, 3\)",
|
||||
),
|
||||
# Wrong shape (dict obs)
|
||||
(
|
||||
spaces.Dict({"obs": spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)}),
|
||||
{"obs": np.array([[1.0, 1.5, 0.5], [1.0, 1.5, 0.5]], dtype=np.float32)},
|
||||
r"Error while checking key=obs.*Expected: \(3,\), actual shape: \(2, 3\)",
|
||||
),
|
||||
# Wrong shape (multi discrete)
|
||||
(
|
||||
spaces.MultiDiscrete([3, 3]),
|
||||
np.array([[2, 0]]),
|
||||
r"Expected: \(2,\), actual shape: \(1, 2\)",
|
||||
),
|
||||
# Wrong shape (multi binary)
|
||||
(
|
||||
spaces.MultiBinary(3),
|
||||
np.array([[1, 0, 0]]),
|
||||
r"Expected: \(3,\), actual shape: \(1, 3\)",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
# Check when it happens at reset or during step
|
||||
"method",
|
||||
["reset", "step"],
|
||||
)
|
||||
def test_check_env_detailed_error(obs_tuple, method):
|
||||
"""
|
||||
Check that the env checker returns more detail error
|
||||
when the observation is not in the obs space.
|
||||
"""
|
||||
observation_space, wrong_obs, error_message = obs_tuple
|
||||
good_obs = observation_space.sample()
|
||||
|
||||
class TestEnv(gym.Env):
|
||||
action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
return wrong_obs if method == "reset" else good_obs
|
||||
|
||||
def step(self, action):
|
||||
obs = wrong_obs if method == "step" else good_obs
|
||||
return obs, 0.0, True, {}
|
||||
|
||||
TestEnv.observation_space = observation_space
|
||||
|
||||
test_env = TestEnv()
|
||||
with pytest.raises(AssertionError, match=error_message):
|
||||
check_env(env=test_env)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
|
@ -233,11 +234,7 @@ def test_report_video_to_tensorboard(tmp_path, read_log, capsys):
|
|||
|
||||
|
||||
def is_moviepy_installed():
|
||||
try:
|
||||
import moviepy
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return True
|
||||
return importlib.util.find_spec("moviepy") is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
|
||||
|
|
|
|||
Loading…
Reference in a new issue