From 0973b01b9dcee853bdd7314db55c4bc524a7f20b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 11:27:59 +0100 Subject: [PATCH] Fix `tests/test_distributions.py` type hint (#1186) * Fixed test_distribution type hint * Impose list[int] for action dim --- docs/misc/changelog.rst | 1 + setup.cfg | 1 - stable_baselines3/common/distributions.py | 2 +- tests/test_distributions.py | 6 ++---- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b4355c7..26b147b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -39,6 +39,7 @@ Others: - Fixed flake8 config to be compatible with flake8 6+ - Goal-conditioned environments are now characterized by the availability of the ``compute_reward`` method, rather than by their inheritance to ``gym.GoalEnv`` - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests +- Fixed ``tests/test_distributions.py`` type hint Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 5e30726..5e41364 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,7 +72,6 @@ exclude = (?x)( | stable_baselines3/sac/sac.py$ | stable_baselines3/td3/policies.py$ | stable_baselines3/td3/td3.py$ - | tests/test_distributions.py$ | tests/test_logger.py$ | tests/test_tensorboard.py$ | tests/test_train_eval_mode.py$ diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 63eb475..b78ef82 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -679,7 +679,7 @@ def make_proba_distribution( elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): - return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) + return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs) elif isinstance(action_space, spaces.MultiBinary): return BernoulliDistribution(action_space.n, **dist_kwargs) else: diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 513429b..e782182 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -55,7 +55,7 @@ def test_squashed_gaussian(model_class): @pytest.fixture() -def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.array, np.array]: +def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]: """ Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env :return: A2C model, random observations, random actions @@ -165,9 +165,7 @@ def test_categorical(dist, CAT_ACTIONS): BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), - MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution( - th.rand(1, sum([N_ACTIONS, N_ACTIONS])) - ), + MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))), SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), StateDependentNoiseDistribution(N_ACTIONS).proba_distribution( th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])