Fix tests/test_distributions.py type hint (#1186)

* Fixed test_distribution type hint

* Impose list[int] for action dim
This commit is contained in:
Quentin Gallouédec 2022-11-29 11:27:59 +01:00 committed by GitHub
parent aee0ba03c7
commit 0973b01b9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 4 additions and 6 deletions

View file

@ -39,6 +39,7 @@ Others:
- Fixed flake8 config to be compatible with flake8 6+
- Goal-conditioned environments are now characterized by the availability of the ``compute_reward`` method, rather than by their inheritance to ``gym.GoalEnv``
- Replaced ``CartPole-v0`` by ``CartPole-v1`` is tests
- Fixed ``tests/test_distributions.py`` type hint
Documentation:
^^^^^^^^^^^^^^

View file

@ -72,7 +72,6 @@ exclude = (?x)(
| stable_baselines3/sac/sac.py$
| stable_baselines3/td3/policies.py$
| stable_baselines3/td3/td3.py$
| tests/test_distributions.py$
| tests/test_logger.py$
| tests/test_tensorboard.py$
| tests/test_train_eval_mode.py$

View file

@ -679,7 +679,7 @@ def make_proba_distribution(
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):
return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
elif isinstance(action_space, spaces.MultiBinary):
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:

View file

@ -55,7 +55,7 @@ def test_squashed_gaussian(model_class):
@pytest.fixture()
def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.array, np.array]:
def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]:
"""
Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env
:return: A2C model, random observations, random actions
@ -165,9 +165,7 @@ def test_categorical(dist, CAT_ACTIONS):
BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution(
th.rand(1, sum([N_ACTIONS, N_ACTIONS]))
),
MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))),
SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
StateDependentNoiseDistribution(N_ACTIONS).proba_distribution(
th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])