From 512eea923afad6f6da4bb53d72b6ea4c6d856e59 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 13 Sep 2024 13:15:23 +0200 Subject: [PATCH] Warn users when using multi-dim MultiDiscrete obs space (#2003) * Update env checker to warn users when using multi-dim MultiDiscrete obs space * Update changelog --- docs/misc/changelog.rst | 10 +++++++++- stable_baselines3/common/env_checker.py | 8 ++++++++ stable_baselines3/version.txt | 2 +- tests/test_envs.py | 2 ++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cc417a9..e8a2984 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a8 (WIP) +Release 2.4.0a9 (WIP) -------------------------- .. note:: @@ -13,6 +13,13 @@ Release 2.4.0a8 (WIP) To suppress the warning, simply save the model again. You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_ +.. warning:: + + Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024) + and PyTorch < 2.0. + We highly recommend that you upgrade to Python >= 3.9 and PyTorch >= 2.0. 
+ + Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -20,6 +27,7 @@ New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) +- Updated env checker to warn users when using multi-dim array to define ``MultiDiscrete`` spaces Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 090d609..e47dd12 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." ) + if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1: + warnings.warn( + f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} " + "which is currently not supported by Stable-Baselines3. " + "Please convert it to a 1D array using a wrapper: " + "https://github.com/DLR-RM/stable-baselines3/issues/1836." 
+ ) + if isinstance(observation_space, spaces.Tuple): warnings.warn( "The observation space is a Tuple, " diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index ee717ba..636c433 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a8 +2.4.0a9 diff --git a/tests/test_envs.py b/tests/test_envs.py index 9a61eee..2fbce12 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -123,6 +123,8 @@ def test_high_dimension_action_space(): spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), # Non zero start index spaces.Discrete(3, start=-1), + # 2D MultiDiscrete + spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])), # Non zero start index (MultiDiscrete) spaces.MultiDiscrete([4, 4], start=[1, 0]), # Non zero start index inside a Dict