Fix distributions type hints (#1733)

* Fix distributions type hints

* Add test for multim binary action space

* Fix test
This commit is contained in:
Antonin RAFFIN 2023-11-06 10:09:01 +01:00 committed by GitHub
parent 294f2b4309
commit 018ea5ab67
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 32 additions and 15 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 2.2.0a9 (WIP)
Release 2.2.0a10 (WIP)
--------------------------
**Support for options at reset, bug fixes and better error messages**
@ -59,6 +59,7 @@ Others:
- Buffers do no call an additional ``.copy()`` when storing new transitions
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
- Fixed ``stable_baselines3/common/distributions.py`` type hints
Documentation:
^^^^^^^^^^^^^^

View file

@ -35,7 +35,8 @@ exclude = [
"stable_baselines3/common/on_policy_algorithm.py",
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py"
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/distributions.py",
]
[tool.mypy]
@ -43,8 +44,7 @@ ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/distributions.py$
| stable_baselines3/common/off_policy_algorithm.py$
stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$

View file

@ -175,7 +175,7 @@ class DiagGaussianDistribution(Distribution):
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
def entropy(self) -> th.Tensor:
def entropy(self) -> Optional[th.Tensor]:
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
@ -216,7 +216,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
super().__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
self.gaussian_actions = None
self.gaussian_actions: Optional[th.Tensor] = None
def proba_distribution(
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
@ -339,7 +339,7 @@ class MultiCategoricalDistribution(Distribution):
def proba_distribution(
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
) -> SelfMultiCategoricalDistribution:
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)]
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
@ -440,6 +440,13 @@ class StateDependentNoiseDistribution(Distribution):
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""
bijector: Optional["TanhBijector"]
latent_sde_dim: Optional[int]
weights_dist: Normal
_latent_sde: th.Tensor
exploration_mat: th.Tensor
exploration_matrices: th.Tensor
def __init__(
self,
action_dim: int,
@ -454,10 +461,6 @@ class StateDependentNoiseDistribution(Distribution):
self.latent_sde_dim = None
self.mean_actions = None
self.log_std = None
self.weights_dist = None
self.exploration_mat = None
self.exploration_matrices = None
self._latent_sde = None
self.use_expln = use_expln
self.full_std = full_std
self.epsilon = epsilon
@ -489,6 +492,7 @@ class StateDependentNoiseDistribution(Distribution):
if self.full_std:
return std
assert self.latent_sde_dim is not None
# Reduce the number of parameters:
return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std
@ -675,10 +679,13 @@ def make_proba_distribution(
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
return CategoricalDistribution(int(action_space.n), **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):
return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
elif isinstance(action_space, spaces.MultiBinary):
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError(
@ -702,7 +709,10 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
if isinstance(dist_pred, MultiCategoricalDistribution):
assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space"
assert isinstance(dist_true, MultiCategoricalDistribution) # already checked above, for mypy
assert np.allclose(
dist_pred.action_dims, dist_true.action_dims
), f"Error: distributions must have the same input space: {dist_pred.action_dims} != {dist_true.action_dims}"
return th.stack(
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
dim=1,

View file

@ -204,7 +204,7 @@ def get_action_dim(action_space: spaces.Space) -> int:
# Number of binary actions
assert isinstance(
action_space.n, int
), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead."
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return int(action_space.n)
else:
raise NotImplementedError(f"{action_space} action space is not supported")

View file

@ -1 +1 @@
2.2.0a9
2.2.0a10

View file

@ -168,3 +168,9 @@ def test_float64_action_space(model_class, obs_space, action_space):
initial_obs, _ = env.reset()
action, _ = model.predict(initial_obs, deterministic=False)
assert action.dtype == env.action_space.dtype
def test_multidim_binary_not_supported():
env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3]))
with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
A2C("MlpPolicy", env)