diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b4355c7..26b147b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -39,6 +39,7 @@ Others: - Fixed flake8 config to be compatible with flake8 6+ - Goal-conditioned environments are now characterized by the availability of the ``compute_reward`` method, rather than by their inheritance to ``gym.GoalEnv`` - Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests +- Fixed type hints in ``tests/test_distributions.py`` Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 5e30726..5e41364 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,7 +72,6 @@ exclude = (?x)( | stable_baselines3/sac/sac.py$ | stable_baselines3/td3/policies.py$ | stable_baselines3/td3/td3.py$ - | tests/test_distributions.py$ | tests/test_logger.py$ | tests/test_tensorboard.py$ | tests/test_train_eval_mode.py$ diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 63eb475..b78ef82 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -679,7 +679,7 @@ def make_proba_distribution( elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): - return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) + return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs) elif isinstance(action_space, spaces.MultiBinary): return BernoulliDistribution(action_space.n, **dist_kwargs) else: diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 513429b..e782182 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -55,7 +55,7 @@ def test_squashed_gaussian(model_class): @pytest.fixture() -def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.array, np.array]: +def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]: 
""" Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env :return: A2C model, random observations, random actions @@ -165,9 +165,7 @@ def test_categorical(dist, CAT_ACTIONS): BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)), DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), - MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution( - th.rand(1, sum([N_ACTIONS, N_ACTIONS])) - ), + MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))), SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)), StateDependentNoiseDistribution(N_ACTIONS).proba_distribution( th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])