Warn users when using multi-dim MultiDiscrete obs space (#2003)

* Update env checker to warn users when using multi-dim MultiDiscrete obs space

* Update changelog
This commit is contained in:
Antonin RAFFIN 2024-09-13 13:15:23 +02:00 committed by GitHub
parent 9a3b28bb9f
commit 512eea923a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 20 additions and 2 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 2.4.0a8 (WIP)
Release 2.4.0a9 (WIP)
--------------------------
.. note::
@ -13,6 +13,13 @@ Release 2.4.0a8 (WIP)
To suppress the warning, simply save the model again.
You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_
.. warning::
Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
and PyTorch < 2.0.
We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0.
Breaking Changes:
^^^^^^^^^^^^^^^^^
@ -20,6 +27,7 @@ New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
Bug Fixes:
^^^^^^^^^^

View file

@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
"is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
)
if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
warnings.warn(
f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
"which is currently not supported by Stable-Baselines3. "
"Please convert it to a 1D array using a wrapper: "
"https://github.com/DLR-RM/stable-baselines3/issues/1836."
)
if isinstance(observation_space, spaces.Tuple):
warnings.warn(
"The observation space is a Tuple, "

View file

@ -1 +1 @@
2.4.0a8
2.4.0a9

View file

@ -123,6 +123,8 @@ def test_high_dimension_action_space():
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
# Non zero start index
spaces.Discrete(3, start=-1),
# 2D MultiDiscrete
spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
# Non zero start index inside a Dict