From 852d635742e97495d200b104ab8e09f128211e26 Mon Sep 17 00:00:00 2001 From: Zikang Xiong <73256697+ZikangXiong@users.noreply.github.com> Date: Tue, 29 Nov 2022 17:33:46 -0500 Subject: [PATCH 1/5] Exposed modules in __init__.py with __all__ (#1195) * Exposed modules in __init__.py with __all__ * Remove flake8 ignore and update root __all__ * Update version Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 5 +++-- setup.cfg | 11 ---------- stable_baselines3/__init__.py | 12 ++++++++++ stable_baselines3/a2c/__init__.py | 2 ++ stable_baselines3/common/envs/__init__.py | 11 ++++++++++ stable_baselines3/common/vec_env/__init__.py | 23 +++++++++++++++++++- stable_baselines3/ddpg/__init__.py | 2 ++ stable_baselines3/dqn/__init__.py | 2 ++ stable_baselines3/her/__init__.py | 2 ++ stable_baselines3/ppo/__init__.py | 2 ++ stable_baselines3/sac/__init__.py | 2 ++ stable_baselines3/td3/__init__.py | 2 ++ stable_baselines3/version.txt | 2 +- 13 files changed, 63 insertions(+), 15 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4885d52..01fc719 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.7.0a4 (WIP) +Release 1.7.0a5 (WIP) -------------------------- Breaking Changes: @@ -43,6 +43,7 @@ Others: - Fixed ``tests/test_distributions.py`` type hint - Fixed ``stable_baselines3/common/type_aliases.py`` type hint - Fixed ``stable_baselines3/common/env_util.py`` type hint +- Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong) Documentation: ^^^^^^^^^^^^^^ @@ -1126,4 +1127,4 @@ And all the contributors: @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 -@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer +@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong diff --git a/setup.cfg b/setup.cfg index 331c8ff..733b5c3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -81,17 +81,6 @@ exclude = (?x)( ignore = W503,W504,E203,E231 # Ignore import not used when aliases are defined per-file-ignores = - ./stable_baselines3/__init__.py:F401 - ./stable_baselines3/common/__init__.py:F401 - ./stable_baselines3/common/envs/__init__.py:F401 - ./stable_baselines3/a2c/__init__.py:F401 - ./stable_baselines3/ddpg/__init__.py:F401 - ./stable_baselines3/dqn/__init__.py:F401 - ./stable_baselines3/her/__init__.py:F401 - ./stable_baselines3/ppo/__init__.py:F401 - ./stable_baselines3/sac/__init__.py:F401 - ./stable_baselines3/td3/__init__.py:F401 - ./stable_baselines3/common/vec_env/__init__.py:F401 # Default implementation in abstract methods ./stable_baselines3/common/callbacks.py:B027 ./stable_baselines3/common/noise.py:B027 diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index d73f5f0..0775a8e 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -20,3 +20,15 @@ def HER(*args, **kwargs): "Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n " "Please check the documentation for more information: https://stable-baselines3.readthedocs.io/" ) + + +__all__ = [ + "A2C", + "DDPG", + "DQN", + "PPO", + "SAC", + "TD3", + "HerReplayBuffer", + "get_system_info", +] diff --git a/stable_baselines3/a2c/__init__.py b/stable_baselines3/a2c/__init__.py index 7e99964..78fc54f 100644 --- a/stable_baselines3/a2c/__init__.py +++ b/stable_baselines3/a2c/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.a2c.a2c import A2C from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "A2C"] diff --git a/stable_baselines3/common/envs/__init__.py b/stable_baselines3/common/envs/__init__.py index 23bd575..3ff0221 100644 --- a/stable_baselines3/common/envs/__init__.py +++ b/stable_baselines3/common/envs/__init__.py @@ -7,3 +7,14 @@ from stable_baselines3.common.envs.identity_env import ( IdentityEnvMultiDiscrete, ) from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv + +__all__ = [ + "BitFlippingEnv", + "FakeImageEnv", + "IdentityEnv", + "IdentityEnvBox", + "IdentityEnvMultiBinary", + "IdentityEnvMultiDiscrete", + "SimpleMultiObsEnv", + "SimpleMultiObsEnv", +] diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 3880fbd..33a103a 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa F401 import typing from copy import deepcopy from typing import Optional, Type, Union @@ -72,3 +71,25 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) env_tmp = env_tmp.venv eval_env_tmp = eval_env_tmp.venv + + +__all__ = [ + "CloudpickleWrapper", + "VecEnv", + "VecEnvWrapper", + "DummyVecEnv", + "StackedDictObservations", + "StackedObservations", + "SubprocVecEnv", + "VecCheckNan", + "VecExtractDictObs", + "VecFrameStack", + "VecMonitor", + "VecNormalize", + "VecTransposeImage", + "VecVideoRecorder", + "unwrap_vec_wrapper", + "unwrap_vec_normalize", + "is_vecenv_wrapped", + "sync_envs_normalization", +] diff --git a/stable_baselines3/ddpg/__init__.py b/stable_baselines3/ddpg/__init__.py index 262e7f1..257a3e3 100644 --- a/stable_baselines3/ddpg/__init__.py +++ b/stable_baselines3/ddpg/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.ddpg.ddpg import DDPG from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DDPG"] diff --git a/stable_baselines3/dqn/__init__.py b/stable_baselines3/dqn/__init__.py index f36f96e..2e5e2db 100644 --- a/stable_baselines3/dqn/__init__.py +++ b/stable_baselines3/dqn/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.dqn.dqn import DQN from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DQN"] diff --git a/stable_baselines3/her/__init__.py b/stable_baselines3/her/__init__.py index 1f58921..dc4c8c2 100644 --- a/stable_baselines3/her/__init__.py +++ b/stable_baselines3/her/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy from stable_baselines3.her.her_replay_buffer import HerReplayBuffer + +__all__ = ["GoalSelectionStrategy", "HerReplayBuffer"] diff --git a/stable_baselines3/ppo/__init__.py b/stable_baselines3/ppo/__init__.py index e5c23fc..cd91257 100644 --- a/stable_baselines3/ppo/__init__.py +++ b/stable_baselines3/ppo/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.ppo.ppo import PPO + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "PPO"] diff --git a/stable_baselines3/sac/__init__.py b/stable_baselines3/sac/__init__.py index 5a84dde..bdf780d 100644 --- a/stable_baselines3/sac/__init__.py +++ b/stable_baselines3/sac/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.sac.sac import SAC + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "SAC"] diff --git a/stable_baselines3/td3/__init__.py b/stable_baselines3/td3/__init__.py index 0b903cd..428141e 100644 --- a/stable_baselines3/td3/__init__.py +++ b/stable_baselines3/td3/__init__.py @@ -1,2 +1,4 @@ from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.td3.td3 import TD3 + +__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TD3"] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 0952a4b..5d819d4 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a4 +1.7.0a5 From 002850f8ace0e045f7e9d370149a6fbb6cbcebad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 23:46:32 +0100 Subject: [PATCH 2/5] Fix `stable_baselines3/common/torch_layers.py` type hint (#1191) * Remove torch layers from mypy exclude * Make torch layers mypy compliant * Extra type specification * Update changelog Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 1 + setup.cfg | 1 - stable_baselines3/common/torch_layers.py | 13 ++++++------- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 01fc719..67b5591 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -42,6 +42,7 @@ Others: - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests - Fixed ``tests/test_distributions.py`` type hint - Fixed ``stable_baselines3/common/type_aliases.py`` type hint +- Fixed ``stable_baselines3/common/torch_layers.py`` type hint - Fixed ``stable_baselines3/common/env_util.py`` type hint - Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong) diff --git a/setup.cfg b/setup.cfg index 733b5c3..045638c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,6 @@ exclude = (?x)( | stable_baselines3/common/preprocessing.py$ | stable_baselines3/common/save_util.py$ | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ - | stable_baselines3/common/torch_layers.py$ | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/base_vec_env.py$ diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 2ce0cc1..5de2af1 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -28,9 +28,6 @@ class BaseFeaturesExtractor(nn.Module): def features_dim(self) -> int: return self._features_dim - def forward(self, observations: th.Tensor) -> th.Tensor: - raise NotImplementedError() - class FlattenExtractor(BaseFeaturesExtractor): """ @@ -173,9 +170,11 @@ class MlpExtractor(nn.Module): ): super().__init__() device = get_device(device) - shared_net, policy_net, value_net = [], [], [] - policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network - value_only_layers = [] # Layer sizes of the network that only belongs to the value network + shared_net: List[nn.Module] = [] + policy_net: List[nn.Module] = [] + value_net: List[nn.Module] = [] + policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network + value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network last_layer_dim_shared = feature_dim # Iterate through the shared layers and build the shared parts of the network @@ -254,7 +253,7 @@ class CombinedExtractor(BaseFeaturesExtractor): # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! super().__init__(observation_space, features_dim=1) - extractors = {} + extractors: Dict[str, nn.Module] = {} total_concat_size = 0 for key, subspace in observation_space.spaces.items(): From f7d7ed3fa7095b560e34210e709c1c8b7b7877e5 Mon Sep 17 00:00:00 2001 From: Athanasios Theocharis Date: Tue, 6 Dec 2022 17:51:52 +0100 Subject: [PATCH 3/5] Update custom_policy.rst (#1183) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update custom_policy.rst * Update changelog Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Antonin RAFFIN Co-authored-by: Antonin Raffin --- docs/guide/custom_policy.rst | 6 +++--- docs/misc/changelog.rst | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 1a3ae34..4ba3203 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -333,11 +333,11 @@ If your task requires even more granular control over the policy/value architect :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network. If all layers are shared, then ``latent_policy == latent_value`` """ - return self.policy_net(features), self.value_net(features) - + return self.forward_actor(features), self.forward_critic(features) + def forward_actor(self, features: th.Tensor) -> th.Tensor: return self.policy_net(features) - + def forward_critic(self, features: th.Tensor) -> th.Tensor: return self.value_net(features) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 67b5591..619e1eb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -50,6 +50,7 @@ Documentation: ^^^^^^^^^^^^^^ - Updated Hugging Face Integration page (@simoninithomas) - Changed ``env`` to ``vec_env`` when environment is vectorized +- Update custom policy documentation (@athatheo) Release 1.6.2 (2022-10-10) -------------------------- From 6763a864c80a0a131bba953ddef9642be0bbc520 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 7 Dec 2022 16:43:47 +0100 Subject: [PATCH 4/5] Upgrade CI/github-actions (#1204) * checkout v2 -> v3; setup-python v2 -> v4 * Update changelog.rst --- .github/workflows/ci.yml | 4 ++-- docs/misc/changelog.rst | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b50cc62..9d22a0a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,9 +22,9 @@ jobs: python-version: [3.7, 3.8, 3.9] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 619e1eb..c3c455b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -45,6 +45,7 @@ Others: - Fixed ``stable_baselines3/common/torch_layers.py`` type hint - Fixed ``stable_baselines3/common/env_util.py`` type hint - Exposed modules in ``__init__.py`` with the ``__all__`` attribute (@ZikangXiong) +- Upgraded GitHub CI/setup-python to v4 and checkout to v3 Documentation: ^^^^^^^^^^^^^^ From e39bc3da00c49413b765176af1b95f2361a35098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 8 Dec 2022 18:46:41 +0100 Subject: [PATCH 5/5] Add support for multidimensional `spaces.MultiBinary` observations (#1179) * Fix `get_obs_shape` for multidimensi onnal Multibinary space * Update changelog * more tests * fix multidiscrete one-hot encoding * refactor tests * Update changelog.rst * Update changelog.rst * batched obs and revert preprocess_obs changes * Add support for multidimensional ``spaces.MultiBinary`` observations Co-authored-by: Antonin RAFFIN Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 5 +- stable_baselines3/common/preprocessing.py | 5 +- stable_baselines3/common/utils.py | 6 +-- tests/test_preprocessing.py | 66 +++++++++++++++++++++++ tests/test_spaces.py | 2 +- 5 files changed, 77 insertions(+), 7 deletions(-) create mode 100644 tests/test_preprocessing.py diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c3c455b..9c65909 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -18,14 +18,15 @@ New Features: ^^^^^^^^^^^^^ - Introduced mypy type checking - Added ``with_bias`` argument to ``create_mlp`` +- Added support for multidimensional ``spaces.MultiBinary`` observations SB3-Contrib ^^^^^^^^^^^ Bug Fixes: ^^^^^^^^^^ -- Fix return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde) -- Fix type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` +- Fixed return type of ``evaluate_actions`` in ``ActorCritcPolicy`` to reflect that entropy is an optional tensor (@Rocamonde) +- Fixed type annotation of ``policy`` in ``BaseAlgorithm`` and ``OffPolicyAlgorithm`` - Allowed model trained with Python 3.7 to be loaded with Python 3.8+ without the ``custom_objects`` workaround - Raise an error when the same gym environment instance is passed as separate environments when creating a vectorized environment with more than one environment. (@Rocamonde) - Fix type annotation of ``model`` in ``evaluate_policy`` diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 01422aa..a406a7d 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -150,7 +150,10 @@ def get_obs_shape( return (int(len(observation_space.nvec)),) elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features - return (int(observation_space.n),) + if type(observation_space.n) in [tuple, list, np.ndarray]: + return tuple(observation_space.n) + else: + return (int(observation_space.n),) elif isinstance(observation_space, spaces.Dict): return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 53c642c..1a3f871 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -299,14 +299,14 @@ def is_vectorized_multibinary_observation(observation: np.ndarray, observation_s :param observation_space: the observation space :return: whether the given observation is vectorized or not """ - if observation.shape == (observation_space.n,): + if observation.shape == observation_space.shape: return False - elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: + elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape: return True else: raise ValueError( f"Error: Unexpected observation shape {observation.shape} for MultiBinary " - + f"environment, please use ({observation_space.n},) or " + + f"environment, please use {observation_space.shape} or " + f"(n_env, {observation_space.n}) for the observation shape." ) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py new file mode 100644 index 0000000..89f869b --- /dev/null +++ b/tests/test_preprocessing.py @@ -0,0 +1,66 @@ +import torch +from gym import spaces + +from stable_baselines3.common.preprocessing import get_obs_shape, preprocess_obs + + +def test_get_obs_shape_discrete(): + assert get_obs_shape(spaces.Discrete(3)) == (1,) + + +def test_get_obs_shape_multidiscrete(): + assert get_obs_shape(spaces.MultiDiscrete([3, 2])) == (2,) + + +def test_get_obs_shape_multibinary(): + assert get_obs_shape(spaces.MultiBinary(3)) == (3,) + + +def test_get_obs_shape_multidimensional_multibinary(): + assert get_obs_shape(spaces.MultiBinary([3, 2])) == (3, 2) + + +def test_get_obs_shape_box(): + assert get_obs_shape(spaces.Box(-2, 2, shape=(3,))) == (3,) + + +def test_get_obs_shape_multidimensional_box(): + assert get_obs_shape(spaces.Box(-2, 2, shape=(3, 2))) == (3, 2) + + +def test_preprocess_obs_discrete(): + actual = preprocess_obs(torch.tensor([2], dtype=torch.long), spaces.Discrete(3)) + expected = torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multidiscrete(): + actual = preprocess_obs(torch.tensor([[2, 0]], dtype=torch.long), spaces.MultiDiscrete([3, 2])) + expected = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multibinary(): + actual = preprocess_obs(torch.tensor([[1, 0, 1]], dtype=torch.long), spaces.MultiBinary(3)) + expected = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multidimensional_multibinary(): + actual = preprocess_obs(torch.tensor([[[1, 0], [1, 1], [0, 1]]], dtype=torch.long), spaces.MultiBinary([3, 2])) + expected = torch.tensor([[[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_box(): + actual = preprocess_obs(torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3,))) + expected = torch.tensor([[1.5, 0.3, -1.8]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) + + +def test_preprocess_obs_multidimensional_box(): + actual = preprocess_obs( + torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32), spaces.Box(-2, 2, shape=(3, 2)) + ) + expected = torch.tensor([[[1.5, 0.3, -1.8], [0.1, -0.6, -1.4]]], dtype=torch.float32) + torch.testing.assert_close(actual, expected) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 0696492..6f530b7 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -47,7 +47,7 @@ class DummyMultidimensionalAction(gym.Env): @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) -@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)]) +@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))]) def test_identity_spaces(model_class, env): """ Additional tests for DQ/SAC/TD3 to check observation space support