mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Update env checker for spaces with non-zero start (#1845)
* Update ruff * Update env checker for non-zero start
This commit is contained in:
parent
1cba1bbd2f
commit
a8e905977f
7 changed files with 43 additions and 23 deletions
2
Makefile
2
Makefile
|
|
@ -18,7 +18,7 @@ type: mypy
|
|||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://www.flake8rules.com/
|
||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
|
||||
# exit-zero treats all errors as warnings.
|
||||
ruff ${LINT_PATHS} --exit-zero
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.3.0a2 (WIP)
|
||||
Release 2.3.0a3 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -55,6 +55,8 @@ Deprecations:
|
|||
Others:
|
||||
^^^^^^^
|
||||
- Updated black from v23 to v24
|
||||
- Updated ruff to >= v0.2.2
|
||||
- Updated env checker for (multi)discrete spaces with non-zero start.
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@
|
|||
line-length = 127
|
||||
# Assume Python 3.8
|
||||
target-version = "py38"
|
||||
|
||||
[tool.ruff.lint]
|
||||
# See https://beta.ruff.rs/docs/rules/
|
||||
select = ["E", "F", "B", "UP", "C90", "RUF"]
|
||||
# B028: Ignore explicit stacklevel`
|
||||
# RUF013: Too many false positives (implicit optional)
|
||||
ignore = ["B028", "RUF013"]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
# Default implementation in abstract methods
|
||||
"./stable_baselines3/common/callbacks.py"= ["B027"]
|
||||
"./stable_baselines3/common/noise.py"= ["B027"]
|
||||
|
|
@ -17,7 +19,7 @@ ignore = ["B028", "RUF013"]
|
|||
"./tests/*.py"= ["RUF012", "RUF013"]
|
||||
|
||||
|
||||
[tool.ruff.mccabe]
|
||||
[tool.ruff.lint.mccabe]
|
||||
# Unlike Flake8, default to a complexity level of 10.
|
||||
max-complexity = 15
|
||||
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -120,7 +120,7 @@ setup(
|
|||
# Type check
|
||||
"mypy",
|
||||
# Lint code and sort imports (flake8 and isort replacement)
|
||||
"ruff>=0.0.288",
|
||||
"ruff>=0.2.2",
|
||||
# Reformat
|
||||
"black>=24.2.0,<25",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -17,13 +17,37 @@ def _is_numpy_array_space(space: spaces.Space) -> bool:
|
|||
return not isinstance(space, (spaces.Dict, spaces.Tuple))
|
||||
|
||||
|
||||
def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool:
|
||||
"""
|
||||
Return False if a (Multi)Discrete space has a non-zero start.
|
||||
"""
|
||||
return np.allclose(space.start, np.zeros_like(space.start))
|
||||
|
||||
|
||||
def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
|
||||
"""
|
||||
:param space: Observation or action space
|
||||
:param space_type: information about whether it is an observation or action space
|
||||
(for the warning message)
|
||||
:param key: When the observation space comes from a Dict space, we pass the
|
||||
corresponding key to have more precise warning messages. Defaults to "".
|
||||
"""
|
||||
if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space):
|
||||
maybe_key = f"(key='{key}')" if key else ""
|
||||
warnings.warn(
|
||||
f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
|
||||
"is not supported by Stable-Baselines3. "
|
||||
f"You can use a wrapper or update your {space_type} space."
|
||||
)
|
||||
|
||||
|
||||
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
|
||||
"""
|
||||
Check that the input will be compatible with Stable-Baselines
|
||||
when the observation is apparently an image.
|
||||
|
||||
:param observation_space: Observation space
|
||||
:key: When the observation space comes from a Dict space, we pass the
|
||||
:param key: When the observation space comes from a Dict space, we pass the
|
||||
corresponding key to have more precise warning messages. Defaults to "".
|
||||
"""
|
||||
if observation_space.dtype != np.uint8:
|
||||
|
|
@ -63,11 +87,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
|
|||
for key, space in observation_space.spaces.items():
|
||||
if isinstance(space, spaces.Dict):
|
||||
nested_dict = True
|
||||
if isinstance(space, spaces.Discrete) and space.start != 0:
|
||||
warnings.warn(
|
||||
f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. "
|
||||
"You can use a wrapper or update your observation space."
|
||||
)
|
||||
_check_non_zero_start(space, "observation", key)
|
||||
|
||||
if nested_dict:
|
||||
warnings.warn(
|
||||
|
|
@ -87,11 +107,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
|
|||
"which is supported by SB3."
|
||||
)
|
||||
|
||||
if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0:
|
||||
warnings.warn(
|
||||
"Discrete observation space with a non-zero start is not supported by Stable-Baselines3. "
|
||||
"You can use a wrapper or update your observation space."
|
||||
)
|
||||
_check_non_zero_start(observation_space, "observation")
|
||||
|
||||
if isinstance(observation_space, spaces.Sequence):
|
||||
warnings.warn(
|
||||
|
|
@ -100,11 +116,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
|
|||
"Note: The checks for returned values are skipped."
|
||||
)
|
||||
|
||||
if isinstance(action_space, spaces.Discrete) and action_space.start != 0:
|
||||
warnings.warn(
|
||||
"Discrete action space with a non-zero start is not supported by Stable-Baselines3. "
|
||||
"You can use a wrapper or update your action space."
|
||||
)
|
||||
_check_non_zero_start(action_space, "action")
|
||||
|
||||
if not _is_numpy_array_space(action_space):
|
||||
warnings.warn(
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.3.0a2
|
||||
2.3.0a3
|
||||
|
|
|
|||
|
|
@ -123,6 +123,8 @@ def test_high_dimension_action_space():
|
|||
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
|
||||
# Non zero start index
|
||||
spaces.Discrete(3, start=-1),
|
||||
# Non zero start index (MultiDiscrete)
|
||||
spaces.MultiDiscrete([4, 4], start=[1, 0]),
|
||||
# Non zero start index inside a Dict
|
||||
spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
|
||||
],
|
||||
|
|
@ -164,6 +166,8 @@ def test_non_default_spaces(new_obs_space):
|
|||
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
|
||||
# Non zero start index
|
||||
spaces.Discrete(3, start=-1),
|
||||
# Non zero start index (MultiDiscrete)
|
||||
spaces.MultiDiscrete([4, 4], start=[1, 0]),
|
||||
],
|
||||
)
|
||||
def test_non_default_action_spaces(new_action_space):
|
||||
|
|
@ -179,7 +183,7 @@ def test_non_default_action_spaces(new_action_space):
|
|||
env.action_space = new_action_space
|
||||
|
||||
# Discrete action space
|
||||
if isinstance(new_action_space, spaces.Discrete):
|
||||
if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)):
|
||||
with pytest.warns(UserWarning):
|
||||
check_env(env)
|
||||
return
|
||||
|
|
|
|||
Loading…
Reference in a new issue