From a8e905977f3073066eb332f063f6335f355c455a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 19 Feb 2024 16:44:02 +0100 Subject: [PATCH] Update env checker for spaces with non-zero start (#1845) * Update ruff * Update env checker for non-zero start --- Makefile | 2 +- docs/misc/changelog.rst | 4 ++- pyproject.toml | 6 ++-- setup.py | 2 +- stable_baselines3/common/env_checker.py | 44 ++++++++++++++++--------- stable_baselines3/version.txt | 2 +- tests/test_envs.py | 6 +++- 7 files changed, 43 insertions(+), 23 deletions(-) diff --git a/Makefile b/Makefile index fe9f6ae..51a5940 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ type: mypy lint: # stop the build if there are Python syntax errors or undefined names # see https://www.flake8rules.com/ - ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source + ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. ruff ${LINT_PATHS} --exit-zero diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cf101af..feb096a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.3.0a2 (WIP) +Release 2.3.0a3 (WIP) -------------------------- Breaking Changes: @@ -55,6 +55,8 @@ Deprecations: Others: ^^^^^^^ - Updated black from v23 to v24 +- Updated ruff to >= v0.2.2 +- Updated env checker for (multi)discrete spaces with non-zero start. Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 1195687..ce0a14e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,13 +3,15 @@ line-length = 127 # Assume Python 3.8 target-version = "py38" + +[tool.ruff.lint] # See https://beta.ruff.rs/docs/rules/ select = ["E", "F", "B", "UP", "C90", "RUF"] # B028: Ignore explicit stacklevel` # RUF013: Too many false positives (implicit optional) ignore = ["B028", "RUF013"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Default implementation in abstract methods "./stable_baselines3/common/callbacks.py"= ["B027"] "./stable_baselines3/common/noise.py"= ["B027"] @@ -17,7 +19,7 @@ ignore = ["B028", "RUF013"] "./tests/*.py"= ["RUF012", "RUF013"] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 diff --git a/setup.py b/setup.py index 763a6a3..a077738 100644 --- a/setup.py +++ b/setup.py @@ -120,7 +120,7 @@ setup( # Type check "mypy", # Lint code and sort imports (flake8 and isort replacement) - "ruff>=0.0.288", + "ruff>=0.2.2", # Reformat "black>=24.2.0,<25", ], diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index dc465a1..f24c86e 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -17,13 +17,37 @@ def _is_numpy_array_space(space: spaces.Space) -> bool: return not isinstance(space, (spaces.Dict, spaces.Tuple)) +def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool: + """ + Return False if a (Multi)Discrete space has a non-zero start. + """ + return np.allclose(space.start, np.zeros_like(space.start)) + + +def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None: + """ + :param space: Observation or action space + :param space_type: information about whether it is an observation or action space + (for the warning message) + :param key: When the observation space comes from a Dict space, we pass the + corresponding key to have more precise warning messages. Defaults to "". + """ + if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space): + maybe_key = f"(key='{key}')" if key else "" + warnings.warn( + f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) " + "is not supported by Stable-Baselines3. " + f"You can use a wrapper or update your {space_type} space." + ) + + def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: """ Check that the input will be compatible with Stable-Baselines when the observation is apparently an image. :param observation_space: Observation space - :key: When the observation space comes from a Dict space, we pass the + :param key: When the observation space comes from a Dict space, we pass the corresponding key to have more precise warning messages. Defaults to "". """ if observation_space.dtype != np.uint8: @@ -63,11 +87,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act for key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True - if isinstance(space, spaces.Discrete) and space.start != 0: - warnings.warn( - f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your observation space." - ) + _check_non_zero_start(space, "observation", key) if nested_dict: warnings.warn( @@ -87,11 +107,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "which is supported by SB3." ) - if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0: - warnings.warn( - "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your observation space." - ) + _check_non_zero_start(observation_space, "observation") if isinstance(observation_space, spaces.Sequence): warnings.warn( @@ -100,11 +116,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "Note: The checks for returned values are skipped." ) - if isinstance(action_space, spaces.Discrete) and action_space.start != 0: - warnings.warn( - "Discrete action space with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your action space." - ) + _check_non_zero_start(action_space, "action") if not _is_numpy_array_space(action_space): warnings.warn( diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 34109b6..5334cfa 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a2 +2.3.0a3 diff --git a/tests/test_envs.py b/tests/test_envs.py index e82ef57..9a61eee 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -123,6 +123,8 @@ def test_high_dimension_action_space(): spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), # Non zero start index spaces.Discrete(3, start=-1), + # Non zero start index (MultiDiscrete) + spaces.MultiDiscrete([4, 4], start=[1, 0]), # Non zero start index inside a Dict spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], @@ -164,6 +166,8 @@ def test_non_default_spaces(new_obs_space): spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), # Non zero start index spaces.Discrete(3, start=-1), + # Non zero start index (MultiDiscrete) + spaces.MultiDiscrete([4, 4], start=[1, 0]), ], ) def test_non_default_action_spaces(new_action_space): @@ -179,7 +183,7 @@ def test_non_default_action_spaces(new_action_space): env.action_space = new_action_space # Discrete action space - if isinstance(new_action_space, spaces.Discrete): + if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)): with pytest.warns(UserWarning): check_env(env) return