mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Review distributions
This commit is contained in:
parent
56fd89da8d
commit
7ba48dce48
1 changed files with 109 additions and 109 deletions
|
|
@ -1,3 +1,6 @@
|
|||
"""Probability distributions."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple, Dict, Any, List
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -8,36 +11,38 @@ from gym import spaces
|
|||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
||||
|
||||
class Distribution(object):
|
||||
class Distribution(ABC):
|
||||
"""Abstract base class for distributions."""
|
||||
|
||||
def __init__(self):
|
||||
super(Distribution, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def log_prob(self, x: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
returns the log likelihood
|
||||
Returns the log likelihood
|
||||
|
||||
:param x: (th.Tensor) the taken action
|
||||
:return: (th.Tensor) The log likelihood of the distribution
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
"""
|
||||
Returns Shannon's entropy of the probability distribution
|
||||
|
||||
:return: (Optional[th.Tensor]) the entropy,
|
||||
return None if no analytical form is known
|
||||
:return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def sample(self) -> th.Tensor:
|
||||
"""
|
||||
Returns a sample from the probability distribution
|
||||
|
||||
:return: (th.Tensor) the stochastic action
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def mode(self) -> th.Tensor:
|
||||
"""
|
||||
Returns the most likely action (deterministic output)
|
||||
|
|
@ -45,7 +50,6 @@ class Distribution(object):
|
|||
|
||||
:return: (th.Tensor) the deterministic action
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_actions(self, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
|
|
@ -58,6 +62,7 @@ class Distribution(object):
|
|||
return self.mode()
|
||||
return self.sample()
|
||||
|
||||
@abstractmethod
|
||||
def actions_from_params(self, *args, **kwargs) -> th.Tensor:
|
||||
"""
|
||||
Returns samples from the probability distribution
|
||||
|
|
@ -65,8 +70,8 @@ class Distribution(object):
|
|||
|
||||
:return: (th.Tensor) actions
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Returns samples and the associated log probabilities
|
||||
|
|
@ -74,14 +79,12 @@ class Distribution(object):
|
|||
|
||||
:return: (Tuple[th.Tensor, th.Tensor]) actions and log prob
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Continuous actions are usually considered to be independent,
|
||||
so we can sum the components for the ``log_prob``
|
||||
or the entropy.
|
||||
so we can sum components of the ``log_prob`` or the entropy.
|
||||
|
||||
:param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,)
|
||||
:return: (th.Tensor) shape: (n_batch,)
|
||||
|
|
@ -95,8 +98,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
|
|||
|
||||
class DiagGaussianDistribution(Distribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix,
|
||||
for continuous actions.
|
||||
Gaussian distribution with diagonal covariance matrix, for continuous actions.
|
||||
|
||||
:param action_dim: (int) Dimension of the action space.
|
||||
"""
|
||||
|
|
@ -115,7 +117,7 @@ class DiagGaussianDistribution(Distribution):
|
|||
one output will be the mean of the Gaussian, the other parameter will be the
|
||||
standard deviation (log std in fact to allow negative values)
|
||||
|
||||
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
|
||||
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:return: (nn.Linear, nn.Parameter)
|
||||
"""
|
||||
|
|
@ -137,15 +139,26 @@ class DiagGaussianDistribution(Distribution):
|
|||
self.distribution = Normal(mean_actions, action_std)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return self.distribution.mean
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilities of actions according to the distribution.
|
||||
Note that you must first call the ``proba_distribution()`` method.
|
||||
|
||||
:param actions: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
return self.distribution.rsample()
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
def mode(self) -> th.Tensor:
|
||||
return self.distribution.mean
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor,
|
||||
|
|
@ -168,22 +181,10 @@ class DiagGaussianDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilities of actions according to the distribution.
|
||||
Note that you must call ``proba_distribution()`` method before.
|
||||
|
||||
:param actions: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
|
||||
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix,
|
||||
followed by a squashing function (tanh) to ensure bounds.
|
||||
Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.
|
||||
|
||||
:param action_dim: (int) Dimension of the action space.
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
|
|
@ -200,27 +201,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
self.gaussian_actions = self.distribution.mean
|
||||
# Squash the output
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
self.gaussian_actions = self.distribution.rsample()
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
action = self.actions_from_params(mean_actions, log_std)
|
||||
log_prob = self.log_prob(action, self.gaussian_actions)
|
||||
return action, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor,
|
||||
gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
|
||||
# Inverse tanh
|
||||
|
|
@ -237,6 +217,27 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1)
|
||||
return log_prob
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
self.gaussian_actions = super().sample()
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
self.gaussian_actions = super().mode()
|
||||
# Squash the output
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
action = self.actions_from_params(mean_actions, log_std)
|
||||
log_prob = self.log_prob(action, self.gaussian_actions)
|
||||
return action, log_prob
|
||||
|
||||
|
||||
class CategoricalDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -267,14 +268,17 @@ class CategoricalDistribution(Distribution):
|
|||
self.distribution = Categorical(logits=action_logits)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.argmax(self.distribution.probs, dim=1)
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy()
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
return self.distribution.sample()
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy()
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.argmax(self.distribution.probs, dim=1)
|
||||
|
||||
def actions_from_params(self, action_logits: th.Tensor,
|
||||
deterministic: bool = False) -> th.Tensor:
|
||||
|
|
@ -287,9 +291,6 @@ class CategoricalDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions)
|
||||
|
||||
|
||||
class MultiCategoricalDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -321,14 +322,19 @@ class MultiCategoricalDistribution(Distribution):
|
|||
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
# Extract each discrete action and compute log prob for their respective distributions
|
||||
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
|
||||
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
return th.stack([dist.sample() for dist in self.distributions], dim=1)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
|
||||
|
||||
def actions_from_params(self, action_logits: th.Tensor,
|
||||
deterministic: bool = False) -> th.Tensor:
|
||||
|
|
@ -341,11 +347,6 @@ class MultiCategoricalDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
# Extract each discrete action and compute log prob for their respective distributions
|
||||
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
|
||||
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
|
||||
|
||||
|
||||
class BernoulliDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -375,14 +376,17 @@ class BernoulliDistribution(Distribution):
|
|||
self.distribution = Bernoulli(logits=action_logits)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.round(self.distribution.probs)
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions).sum(dim=1)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy().sum(dim=1)
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
return self.distribution.sample()
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy().sum(dim=1)
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.round(self.distribution.probs)
|
||||
|
||||
def actions_from_params(self, action_logits: th.Tensor,
|
||||
deterministic: bool = False) -> th.Tensor:
|
||||
|
|
@ -395,9 +399,6 @@ class BernoulliDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions).sum(dim=1)
|
||||
|
||||
|
||||
class StateDependentNoiseDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -414,7 +415,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
a positive standard deviation (cf paper). It keeps the variance
|
||||
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries.
|
||||
this ensures bounds are satisfied.
|
||||
:param learn_features: (bool) Whether to learn features for gSDE or not.
|
||||
This will enable gradients to be backpropagated through the features
|
||||
``latent_sde`` in the code.
|
||||
|
|
@ -529,6 +530,35 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
if self.bijector is not None:
|
||||
gaussian_actions = self.bijector.inverse(actions)
|
||||
else:
|
||||
gaussian_actions = actions
|
||||
# log likelihood for a gaussian
|
||||
log_prob = self.distribution.log_prob(gaussian_actions)
|
||||
# Sum along action dim
|
||||
log_prob = sum_independent_dims(log_prob)
|
||||
|
||||
if self.bijector is not None:
|
||||
# Squash correction (from original SAC implementation)
|
||||
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
|
||||
return log_prob
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
if self.bijector is not None:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
noise = self.get_noise(self._latent_sde)
|
||||
actions = self.distribution.mean + noise
|
||||
if self.bijector is not None:
|
||||
return self.bijector.forward(actions)
|
||||
return actions
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
actions = self.distribution.mean
|
||||
if self.bijector is not None:
|
||||
|
|
@ -547,20 +577,6 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
noise = th.bmm(latent_sde, self.exploration_matrices)
|
||||
return noise.squeeze(1)
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
noise = self.get_noise(self._latent_sde)
|
||||
actions = self.distribution.mean + noise
|
||||
if self.bijector is not None:
|
||||
return self.bijector.forward(actions)
|
||||
return actions
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
if self.bijector is not None:
|
||||
return None
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor,
|
||||
latent_sde: th.Tensor,
|
||||
|
|
@ -576,21 +592,6 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
if self.bijector is not None:
|
||||
gaussian_actions = self.bijector.inverse(actions)
|
||||
else:
|
||||
gaussian_actions = actions
|
||||
# log likelihood for a gaussian
|
||||
log_prob = self.distribution.log_prob(gaussian_actions)
|
||||
# Sum along action dim
|
||||
log_prob = sum_independent_dims(log_prob)
|
||||
|
||||
if self.bijector is not None:
|
||||
# Squash correction (from original SAC implementation)
|
||||
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
|
||||
return log_prob
|
||||
|
||||
|
||||
class TanhBijector(object):
|
||||
"""
|
||||
|
|
@ -653,9 +654,8 @@ def make_proba_distribution(action_space: gym.spaces.Space,
|
|||
|
||||
if isinstance(action_space, spaces.Box):
|
||||
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
if use_sde:
|
||||
return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
|
||||
cls(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
|
|
|
|||
Loading…
Reference in a new issue