Review distributions

This commit is contained in:
Adam Gleave 2020-07-02 19:18:51 -07:00
parent 56fd89da8d
commit 7ba48dce48

View file

@ -1,3 +1,6 @@
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict, Any, List
import gym
import torch as th
@ -8,36 +11,38 @@ from gym import spaces
from stable_baselines3.common.preprocessing import get_action_dim
class Distribution(object):
class Distribution(ABC):
"""Abstract base class for distributions."""
def __init__(self):
    """Initialise the base class; no distribution state is created here."""
    super().__init__()
@abstractmethod
def log_prob(self, x: th.Tensor) -> th.Tensor:
    """
    Returns the log likelihood

    :param x: (th.Tensor) the taken action
    :return: (th.Tensor) The log likelihood of the distribution
    """
    # Abstract hook: concrete distributions must override this method.
    raise NotImplementedError
@abstractmethod
def entropy(self) -> Optional[th.Tensor]:
    """
    Returns Shannon's entropy of the probability

    :return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known
    """
    # Abstract hook: concrete distributions must override this method.
    raise NotImplementedError
@abstractmethod
def sample(self) -> th.Tensor:
    """
    Draw one sample from the probability distribution.

    :return: (th.Tensor) the stochastic action
    """
    # Abstract hook: the base class provides no sampling logic.
    raise NotImplementedError()
@abstractmethod
def mode(self) -> th.Tensor:
"""
Returns the most likely action (deterministic output)
@ -45,7 +50,6 @@ class Distribution(object):
:return: (th.Tensor) the deterministic (most likely) action
"""
raise NotImplementedError
def get_actions(self, deterministic: bool = False) -> th.Tensor:
"""
@ -58,6 +62,7 @@ class Distribution(object):
return self.mode()
return self.sample()
@abstractmethod
def actions_from_params(self, *args, **kwargs) -> th.Tensor:
"""
Returns samples from the probability distribution
@ -65,8 +70,8 @@ class Distribution(object):
:return: (th.Tensor) actions
"""
raise NotImplementedError
@abstractmethod
def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
"""
Returns samples and the associated log probabilities
@ -74,14 +79,12 @@ class Distribution(object):
:return: (Tuple[th.Tensor, th.Tensor]) actions and log prob
"""
raise NotImplementedError
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
"""
Continuous actions are usually considered to be independent,
so we can sum the components for the ``log_prob``
or the entropy.
so we can sum components of the ``log_prob`` or the entropy.
:param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,)
:return: (th.Tensor) shape: (n_batch,)
@ -95,8 +98,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
class DiagGaussianDistribution(Distribution):
"""
Gaussian distribution with diagonal covariance matrix,
for continuous actions.
Gaussian distribution with diagonal covariance matrix, for continuous actions.
:param action_dim: (int) Dimension of the action space.
"""
@ -115,7 +117,7 @@ class DiagGaussianDistribution(Distribution):
one output will be the mean of the Gaussian, the other parameter will be the
standard deviation (log std in fact to allow negative values)
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
:param log_std_init: (float) Initial value for the log standard deviation
:return: (nn.Linear, nn.Parameter)
"""
@ -137,15 +139,26 @@ class DiagGaussianDistribution(Distribution):
self.distribution = Normal(mean_actions, action_std)
return self
def mode(self) -> th.Tensor:
    """
    Deterministic action: the mean of the current Gaussian.

    :return: (th.Tensor) ``self.distribution.mean``
    """
    gaussian = self.distribution
    return gaussian.mean
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions`` under the current distribution.
    ``proba_distribution()`` has to be called first so that
    ``self.distribution`` exists.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    # Per-component log-likelihoods are reduced by ``sum_independent_dims``
    return sum_independent_dims(self.distribution.log_prob(actions))
def entropy(self) -> th.Tensor:
    """
    Entropy of the current Gaussian, reduced with ``sum_independent_dims``.

    :return: (th.Tensor)
    """
    per_dim_entropy = self.distribution.entropy()
    return sum_independent_dims(per_dim_entropy)
def sample(self) -> th.Tensor:
    """
    Draw a stochastic action from the current Gaussian.

    :return: (th.Tensor)
    """
    # rsample() applies the reparametrization trick so gradients can flow
    reparametrized = self.distribution.rsample()
    return reparametrized
def entropy(self) -> th.Tensor:
    """
    Entropy of the current Gaussian, reduced with ``sum_independent_dims``.

    :return: (th.Tensor)
    """
    per_dim_entropy = self.distribution.entropy()
    return sum_independent_dims(per_dim_entropy)
def mode(self) -> th.Tensor:
    """
    Deterministic action: the mean of the current Gaussian.

    :return: (th.Tensor) ``self.distribution.mean``
    """
    gaussian = self.distribution
    return gaussian.mean
def actions_from_params(self, mean_actions: th.Tensor,
log_std: th.Tensor,
@ -168,22 +181,10 @@ class DiagGaussianDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions`` under the current distribution.
    ``proba_distribution()`` has to be called first so that
    ``self.distribution`` exists.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    # Per-component log-likelihoods are reduced by ``sum_independent_dims``
    return sum_independent_dims(self.distribution.log_prob(actions))
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
"""
Gaussian distribution with diagonal covariance matrix,
followed by a squashing function (tanh) to ensure bounds.
Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.
:param action_dim: (int) Dimension of the action space.
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
@ -200,27 +201,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
return self
def mode(self) -> th.Tensor:
    """
    Deterministic action: the Gaussian mean passed through ``tanh``.

    :return: (th.Tensor) the squashed mean
    """
    unsquashed = self.distribution.mean
    # Keep the pre-tanh value; ``log_prob_from_params`` passes it to ``log_prob``
    self.gaussian_actions = unsquashed
    return th.tanh(unsquashed)
def entropy(self) -> Optional[th.Tensor]:
    """
    No analytical entropy is available for this distribution.
    It has to be estimated using ``-log_prob.mean()`` instead.

    :return: (Optional[th.Tensor]) always ``None``
    """
    return None
def sample(self) -> th.Tensor:
    """
    Draw a stochastic action and squash it with ``tanh``.

    :return: (th.Tensor) the squashed sample
    """
    # rsample() applies the reparametrization trick so gradients can flow
    unsquashed = self.distribution.rsample()
    # Keep the pre-tanh value; ``log_prob_from_params`` passes it to ``log_prob``
    self.gaussian_actions = unsquashed
    return th.tanh(unsquashed)
def log_prob_from_params(self, mean_actions: th.Tensor,
                         log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
    """
    Compute actions from the distribution parameters together with their log probability.

    :param mean_actions: (th.Tensor)
    :param log_std: (th.Tensor)
    :return: (Tuple[th.Tensor, th.Tensor]) actions and log prob
    """
    squashed_actions = self.actions_from_params(mean_actions, log_std)
    # ``self.gaussian_actions`` is read only after the call above has run
    log_likelihood = self.log_prob(squashed_actions, self.gaussian_actions)
    return squashed_actions, log_likelihood
def log_prob(self, actions: th.Tensor,
gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
# Inverse tanh
@ -237,6 +217,27 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1)
return log_prob
def entropy(self) -> Optional[th.Tensor]:
    """
    No analytical entropy is available for this distribution.
    It has to be estimated using ``-log_prob.mean()`` instead.

    :return: (Optional[th.Tensor]) always ``None``
    """
    return None
def sample(self) -> th.Tensor:
    """
    Draw a stochastic action from the parent class and squash it with ``tanh``.

    :return: (th.Tensor) the squashed sample
    """
    # Reparametrization trick to pass gradients (done by the parent ``sample``)
    unsquashed = super().sample()
    # Keep the pre-tanh value; ``log_prob_from_params`` passes it to ``log_prob``
    self.gaussian_actions = unsquashed
    return th.tanh(unsquashed)
def mode(self) -> th.Tensor:
    """
    Deterministic action: the parent class' mode passed through ``tanh``.

    :return: (th.Tensor) the squashed mode
    """
    unsquashed = super().mode()
    # Keep the pre-tanh value; ``log_prob_from_params`` passes it to ``log_prob``
    self.gaussian_actions = unsquashed
    return th.tanh(unsquashed)
def log_prob_from_params(self, mean_actions: th.Tensor,
                         log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
    """
    Compute actions from the distribution parameters together with their log probability.

    :param mean_actions: (th.Tensor)
    :param log_std: (th.Tensor)
    :return: (Tuple[th.Tensor, th.Tensor]) actions and log prob
    """
    squashed_actions = self.actions_from_params(mean_actions, log_std)
    # ``self.gaussian_actions`` is read only after the call above has run
    log_likelihood = self.log_prob(squashed_actions, self.gaussian_actions)
    return squashed_actions, log_likelihood
class CategoricalDistribution(Distribution):
"""
@ -267,14 +268,17 @@ class CategoricalDistribution(Distribution):
self.distribution = Categorical(logits=action_logits)
return self
def mode(self) -> th.Tensor:
    """
    Deterministic action: index of the largest probability along dim 1.

    :return: (th.Tensor)
    """
    probabilities = self.distribution.probs
    return th.argmax(probabilities, dim=1)
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions`` under the current categorical distribution.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    categorical = self.distribution
    return categorical.log_prob(actions)
def entropy(self) -> th.Tensor:
    """
    Entropy of the current categorical distribution.

    :return: (th.Tensor)
    """
    categorical = self.distribution
    return categorical.entropy()
def sample(self) -> th.Tensor:
    """
    Draw a stochastic action from the current categorical distribution.

    :return: (th.Tensor)
    """
    categorical = self.distribution
    return categorical.sample()
def entropy(self) -> th.Tensor:
    """
    Entropy of the current categorical distribution.

    :return: (th.Tensor)
    """
    categorical = self.distribution
    return categorical.entropy()
def mode(self) -> th.Tensor:
    """
    Deterministic action: index of the largest probability along dim 1.

    :return: (th.Tensor)
    """
    probabilities = self.distribution.probs
    return th.argmax(probabilities, dim=1)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
@ -287,9 +291,6 @@ class CategoricalDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions`` under the current categorical distribution.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    categorical = self.distribution
    return categorical.log_prob(actions)
class MultiCategoricalDistribution(Distribution):
"""
@ -321,14 +322,19 @@ class MultiCategoricalDistribution(Distribution):
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self
def mode(self) -> th.Tensor:
    """
    Deterministic action: the argmax of every categorical, stacked along dim 1.

    :return: (th.Tensor)
    """
    most_likely = []
    for categorical in self.distributions:
        most_likely.append(th.argmax(categorical.probs, dim=1))
    return th.stack(most_likely, dim=1)
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of a batch of multi-discrete actions.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    # Column i of ``actions`` is scored by the i-th categorical
    per_dimension_actions = th.unbind(actions, dim=1)
    log_probs = [
        categorical.log_prob(action)
        for categorical, action in zip(self.distributions, per_dimension_actions)
    ]
    # The joint log probability is the sum over the stacked per-dimension terms
    return th.stack(log_probs, dim=1).sum(dim=1)
def entropy(self) -> th.Tensor:
    """
    Sum of the entropies of all the categoricals.

    :return: (th.Tensor)
    """
    entropies = [categorical.entropy() for categorical in self.distributions]
    stacked = th.stack(entropies, dim=1)
    return stacked.sum(dim=1)
def sample(self) -> th.Tensor:
    """
    Draw one sample from every categorical and stack them along dim 1.

    :return: (th.Tensor)
    """
    draws = []
    for categorical in self.distributions:
        draws.append(categorical.sample())
    return th.stack(draws, dim=1)
def entropy(self) -> th.Tensor:
    """
    Sum of the entropies of all the categoricals.

    :return: (th.Tensor)
    """
    entropies = [categorical.entropy() for categorical in self.distributions]
    stacked = th.stack(entropies, dim=1)
    return stacked.sum(dim=1)
def mode(self) -> th.Tensor:
    """
    Deterministic action: the argmax of every categorical, stacked along dim 1.

    :return: (th.Tensor)
    """
    most_likely = []
    for categorical in self.distributions:
        most_likely.append(th.argmax(categorical.probs, dim=1))
    return th.stack(most_likely, dim=1)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
@ -341,11 +347,6 @@ class MultiCategoricalDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of a batch of multi-discrete actions.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    # Column i of ``actions`` is scored by the i-th categorical
    per_dimension_actions = th.unbind(actions, dim=1)
    log_probs = [
        categorical.log_prob(action)
        for categorical, action in zip(self.distributions, per_dimension_actions)
    ]
    # The joint log probability is the sum over the stacked per-dimension terms
    return th.stack(log_probs, dim=1).sum(dim=1)
class BernoulliDistribution(Distribution):
"""
@ -375,14 +376,17 @@ class BernoulliDistribution(Distribution):
self.distribution = Bernoulli(logits=action_logits)
return self
def mode(self) -> th.Tensor:
    """
    Deterministic action: the Bernoulli probabilities rounded with ``th.round``.

    :return: (th.Tensor)
    """
    probabilities = self.distribution.probs
    return th.round(probabilities)
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions``, summed over dim 1.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    per_bit_log_prob = self.distribution.log_prob(actions)
    return per_bit_log_prob.sum(dim=1)
def entropy(self) -> th.Tensor:
    """
    Entropy of the Bernoulli distribution, summed over dim 1.

    :return: (th.Tensor)
    """
    per_bit_entropy = self.distribution.entropy()
    return per_bit_entropy.sum(dim=1)
def sample(self) -> th.Tensor:
    """
    Draw a stochastic action from the current Bernoulli distribution.

    :return: (th.Tensor)
    """
    bernoulli = self.distribution
    return bernoulli.sample()
def entropy(self) -> th.Tensor:
    """
    Entropy of the Bernoulli distribution, summed over dim 1.

    :return: (th.Tensor)
    """
    per_bit_entropy = self.distribution.entropy()
    return per_bit_entropy.sum(dim=1)
def mode(self) -> th.Tensor:
    """
    Deterministic action: the Bernoulli probabilities rounded with ``th.round``.

    :return: (th.Tensor)
    """
    probabilities = self.distribution.probs
    return th.round(probabilities)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
@ -395,9 +399,6 @@ class BernoulliDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions``, summed over dim 1.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    per_bit_log_prob = self.distribution.log_prob(actions)
    return per_bit_log_prob.sum(dim=1)
class StateDependentNoiseDistribution(Distribution):
"""
@ -414,7 +415,7 @@ class StateDependentNoiseDistribution(Distribution):
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries.
this ensures bounds are satisfied.
:param learn_features: (bool) Whether to learn features for gSDE or not.
This will enable gradients to be backpropagated through the features
``latent_sde`` in the code.
@ -529,6 +530,35 @@ class StateDependentNoiseDistribution(Distribution):
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions``, with a correction term when a bijector is set.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    has_bijector = self.bijector is not None
    # Map the actions back through the bijector when there is one
    gaussian_actions = self.bijector.inverse(actions) if has_bijector else actions
    # Gaussian log likelihood, summed along the action dim
    log_prob = sum_independent_dims(self.distribution.log_prob(gaussian_actions))
    if has_bijector:
        # Squash correction (from original SAC implementation)
        log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
    return log_prob
def entropy(self) -> Optional[th.Tensor]:
    """
    Entropy of the distribution, or ``None`` when a bijector is set.

    :return: (Optional[th.Tensor])
    """
    if self.bijector is None:
        return sum_independent_dims(self.distribution.entropy())
    # No analytical form,
    # entropy needs to be estimated using -log_prob.mean()
    return None
def sample(self) -> th.Tensor:
    """
    Stochastic action: the distribution mean plus the noise from ``get_noise``.

    :return: (th.Tensor) the action, passed through the bijector when one is set
    """
    exploration_noise = self.get_noise(self._latent_sde)
    noisy_actions = self.distribution.mean + exploration_noise
    if self.bijector is None:
        return noisy_actions
    return self.bijector.forward(noisy_actions)
def mode(self) -> th.Tensor:
actions = self.distribution.mean
if self.bijector is not None:
@ -547,20 +577,6 @@ class StateDependentNoiseDistribution(Distribution):
noise = th.bmm(latent_sde, self.exploration_matrices)
return noise.squeeze(1)
def sample(self) -> th.Tensor:
    """
    Stochastic action: the distribution mean plus the noise from ``get_noise``.

    :return: (th.Tensor) the action, passed through the bijector when one is set
    """
    exploration_noise = self.get_noise(self._latent_sde)
    noisy_actions = self.distribution.mean + exploration_noise
    if self.bijector is None:
        return noisy_actions
    return self.bijector.forward(noisy_actions)
def entropy(self) -> Optional[th.Tensor]:
    """
    Entropy of the distribution, or ``None`` when a bijector is set.

    :return: (Optional[th.Tensor])
    """
    if self.bijector is None:
        return sum_independent_dims(self.distribution.entropy())
    # No analytical form,
    # entropy needs to be estimated using -log_prob.mean()
    return None
def actions_from_params(self, mean_actions: th.Tensor,
log_std: th.Tensor,
latent_sde: th.Tensor,
@ -576,21 +592,6 @@ class StateDependentNoiseDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
    """
    Log probability of ``actions``, with a correction term when a bijector is set.

    :param actions: (th.Tensor)
    :return: (th.Tensor)
    """
    has_bijector = self.bijector is not None
    # Map the actions back through the bijector when there is one
    gaussian_actions = self.bijector.inverse(actions) if has_bijector else actions
    # Gaussian log likelihood, summed along the action dim
    log_prob = sum_independent_dims(self.distribution.log_prob(gaussian_actions))
    if has_bijector:
        # Squash correction (from original SAC implementation)
        log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
    return log_prob
class TanhBijector(object):
"""
@ -653,9 +654,8 @@ def make_proba_distribution(action_space: gym.spaces.Space,
if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
if use_sde:
return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs)
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):