mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-08 00:23:22 +00:00
Fix tests/test_distributions.py type hint (#1186)
* Fixed test_distribution type hint * Impose list[int] for action dim
This commit is contained in:
parent
aee0ba03c7
commit
0973b01b9d
4 changed files with 4 additions and 6 deletions
|
|
@ -39,6 +39,7 @@ Others:
|
|||
- Fixed flake8 config to be compatible with flake8 6+
|
||||
- Goal-conditioned environments are now characterized by the availability of the ``compute_reward`` method, rather than by their inheritance from ``gym.GoalEnv``
|
||||
- Replaced ``CartPole-v0`` by ``CartPole-v1`` in tests
|
||||
- Fixed ``tests/test_distributions.py`` type hint
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -72,7 +72,6 @@ exclude = (?x)(
|
|||
| stable_baselines3/sac/sac.py$
|
||||
| stable_baselines3/td3/policies.py$
|
||||
| stable_baselines3/td3/td3.py$
|
||||
| tests/test_distributions.py$
|
||||
| tests/test_logger.py$
|
||||
| tests/test_tensorboard.py$
|
||||
| tests/test_train_eval_mode.py$
|
||||
|
|
|
|||
|
|
@ -679,7 +679,7 @@ def make_proba_distribution(
|
|||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
|
||||
return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiBinary):
|
||||
return BernoulliDistribution(action_space.n, **dist_kwargs)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def test_squashed_gaussian(model_class):
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.array, np.array]:
|
||||
def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env
|
||||
:return: A2C model, random observations, random actions
|
||||
|
|
@ -165,9 +165,7 @@ def test_categorical(dist, CAT_ACTIONS):
|
|||
BernoulliDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
|
||||
CategoricalDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS)),
|
||||
DiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
|
||||
MultiCategoricalDistribution(np.array([N_ACTIONS, N_ACTIONS])).proba_distribution(
|
||||
th.rand(1, sum([N_ACTIONS, N_ACTIONS]))
|
||||
),
|
||||
MultiCategoricalDistribution([N_ACTIONS, N_ACTIONS]).proba_distribution(th.rand(1, sum([N_ACTIONS, N_ACTIONS]))),
|
||||
SquashedDiagGaussianDistribution(N_ACTIONS).proba_distribution(th.rand(N_ACTIONS), th.rand(N_ACTIONS)),
|
||||
StateDependentNoiseDistribution(N_ACTIONS).proba_distribution(
|
||||
th.rand(N_ACTIONS), th.rand([N_ACTIONS, N_ACTIONS]), th.rand([N_ACTIONS, N_ACTIONS])
|
||||
|
|
|
|||
Loading…
Reference in a new issue