mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix distributions type hints (#1733)
* Fix distributions type hints * Add test for multidim binary action space * Fix test
This commit is contained in:
parent
294f2b4309
commit
018ea5ab67
6 changed files with 32 additions and 15 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0a9 (WIP)
|
||||
Release 2.2.0a10 (WIP)
|
||||
--------------------------
|
||||
**Support for options at reset, bug fixes and better error messages**
|
||||
|
||||
|
|
@ -59,6 +59,7 @@ Others:
|
|||
- Buffers do not call an additional ``.copy()`` when storing new transitions
|
||||
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
|
||||
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
|
||||
- Fixed ``stable_baselines3/common/distributions.py`` type hints
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ exclude = [
|
|||
"stable_baselines3/common/on_policy_algorithm.py",
|
||||
"stable_baselines3/common/vec_env/stacked_observations.py",
|
||||
"stable_baselines3/common/vec_env/subproc_vec_env.py",
|
||||
"stable_baselines3/common/vec_env/patch_gym.py"
|
||||
"stable_baselines3/common/vec_env/patch_gym.py",
|
||||
"stable_baselines3/common/distributions.py",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
|
@ -43,8 +44,7 @@ ignore_missing_imports = true
|
|||
follow_imports = "silent"
|
||||
show_error_codes = true
|
||||
exclude = """(?x)(
|
||||
stable_baselines3/common/distributions.py$
|
||||
| stable_baselines3/common/off_policy_algorithm.py$
|
||||
stable_baselines3/common/off_policy_algorithm.py$
|
||||
| stable_baselines3/common/policies.py$
|
||||
| stable_baselines3/common/vec_env/__init__.py$
|
||||
| stable_baselines3/common/vec_env/vec_normalize.py$
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ class DiagGaussianDistribution(Distribution):
|
|||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
|
|
@ -216,7 +216,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
super().__init__(action_dim)
|
||||
# Avoid NaN (prevents division by zero or log of zero)
|
||||
self.epsilon = epsilon
|
||||
self.gaussian_actions = None
|
||||
self.gaussian_actions: Optional[th.Tensor] = None
|
||||
|
||||
def proba_distribution(
|
||||
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
|
||||
|
|
@ -339,7 +339,7 @@ class MultiCategoricalDistribution(Distribution):
|
|||
def proba_distribution(
|
||||
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
|
||||
) -> SelfMultiCategoricalDistribution:
|
||||
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
|
||||
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)]
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
|
|
@ -440,6 +440,13 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
:param epsilon: small value to avoid NaN due to numerical imprecision.
|
||||
"""
|
||||
|
||||
bijector: Optional["TanhBijector"]
|
||||
latent_sde_dim: Optional[int]
|
||||
weights_dist: Normal
|
||||
_latent_sde: th.Tensor
|
||||
exploration_mat: th.Tensor
|
||||
exploration_matrices: th.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_dim: int,
|
||||
|
|
@ -454,10 +461,6 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.latent_sde_dim = None
|
||||
self.mean_actions = None
|
||||
self.log_std = None
|
||||
self.weights_dist = None
|
||||
self.exploration_mat = None
|
||||
self.exploration_matrices = None
|
||||
self._latent_sde = None
|
||||
self.use_expln = use_expln
|
||||
self.full_std = full_std
|
||||
self.epsilon = epsilon
|
||||
|
|
@ -489,6 +492,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
|
||||
if self.full_std:
|
||||
return std
|
||||
assert self.latent_sde_dim is not None
|
||||
# Reduce the number of parameters:
|
||||
return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std
|
||||
|
||||
|
|
@ -675,10 +679,13 @@ def make_proba_distribution(
|
|||
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
|
||||
return cls(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
return CategoricalDistribution(int(action_space.n), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiBinary):
|
||||
assert isinstance(
|
||||
action_space.n, int
|
||||
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
|
||||
return BernoulliDistribution(action_space.n, **dist_kwargs)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
|
@ -702,7 +709,10 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
|
|||
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
|
||||
# so we need to implement it ourselves!
|
||||
if isinstance(dist_pred, MultiCategoricalDistribution):
|
||||
assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space"
|
||||
assert isinstance(dist_true, MultiCategoricalDistribution) # already checked above, for mypy
|
||||
assert np.allclose(
|
||||
dist_pred.action_dims, dist_true.action_dims
|
||||
), f"Error: distributions must have the same input space: {dist_pred.action_dims} != {dist_true.action_dims}"
|
||||
return th.stack(
|
||||
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
|
||||
dim=1,
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ def get_action_dim(action_space: spaces.Space) -> int:
|
|||
# Number of binary actions
|
||||
assert isinstance(
|
||||
action_space.n, int
|
||||
), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead."
|
||||
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
|
||||
return int(action_space.n)
|
||||
else:
|
||||
raise NotImplementedError(f"{action_space} action space is not supported")
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a9
|
||||
2.2.0a10
|
||||
|
|
|
|||
|
|
@ -168,3 +168,9 @@ def test_float64_action_space(model_class, obs_space, action_space):
|
|||
initial_obs, _ = env.reset()
|
||||
action, _ = model.predict(initial_obs, deterministic=False)
|
||||
assert action.dtype == env.action_space.dtype
|
||||
|
||||
|
||||
def test_multidim_binary_not_supported():
|
||||
env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3]))
|
||||
with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
|
||||
A2C("MlpPolicy", env)
|
||||
|
|
|
|||
Loading…
Reference in a new issue