diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 951f163..068bba5 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,3 +1,6 @@ +"""Probability distributions.""" + +from abc import ABC, abstractmethod from typing import Optional, Tuple, Dict, Any, List import gym import torch as th @@ -8,36 +11,38 @@ from gym import spaces from stable_baselines3.common.preprocessing import get_action_dim -class Distribution(object): +class Distribution(ABC): + """Abstract base class for distributions.""" + def __init__(self): super(Distribution, self).__init__() + @abstractmethod def log_prob(self, x: th.Tensor) -> th.Tensor: """ - returns the log likelihood + Returns the log likelihood :param x: (th.Tensor) the taken action :return: (th.Tensor) The log likelihood of the distribution """ - raise NotImplementedError + @abstractmethod def entropy(self) -> Optional[th.Tensor]: """ Returns Shannon's entropy of the probability - :return: (Optional[th.Tensor]) the entropy, - return None if no analytical form is known + :return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known """ - raise NotImplementedError + @abstractmethod def sample(self) -> th.Tensor: """ Returns a sample from the probability distribution :return: (th.Tensor) the stochastic action """ - raise NotImplementedError + @abstractmethod def mode(self) -> th.Tensor: """ Returns the most likely action (deterministic output) @@ -45,7 +50,6 @@ class Distribution(object): :return: (th.Tensor) the stochastic action """ - raise NotImplementedError def get_actions(self, deterministic: bool = False) -> th.Tensor: """ @@ -58,6 +62,7 @@ class Distribution(object): return self.mode() return self.sample() + @abstractmethod def actions_from_params(self, *args, **kwargs) -> th.Tensor: """ Returns samples from the probability distribution @@ -65,8 +70,8 @@ class Distribution(object): :return: (th.Tensor) actions """ - raise NotImplementedError + @abstractmethod def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]: """ Returns samples and the associated log probabilities @@ -74,14 +79,12 @@ class Distribution(object): :return: (th.Tuple[th.Tensor, th.Tensor]) actions and log prob """ - raise NotImplementedError def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: """ Continuous actions are usually considered to be independent, - so we can sum the components for the ``log_prob`` - or the entropy. + so we can sum components of the ``log_prob`` or the entropy. :param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,) :return: (th.Tensor) shape: (n_batch,) @@ -95,8 +98,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: class DiagGaussianDistribution(Distribution): """ - Gaussian distribution with diagonal covariance matrix, - for continuous actions. + Gaussian distribution with diagonal covariance matrix, for continuous actions. :param action_dim: (int) Dimension of the action space. """ @@ -115,7 +117,7 @@ class DiagGaussianDistribution(Distribution): one output will be the mean of the Gaussian, the other parameter will be the standard deviation (log std in fact to allow negative values) - :param latent_dim: (int) Dimension og the last layer of the policy (before the action layer) + :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer) :param log_std_init: (float) Initial value for the log standard deviation :return: (nn.Linear, nn.Parameter) """ @@ -137,15 +139,26 @@ class DiagGaussianDistribution(Distribution): self.distribution = Normal(mean_actions, action_std) return self - def mode(self) -> th.Tensor: - return self.distribution.mean + def log_prob(self, actions: th.Tensor) -> th.Tensor: + """ + Get the log probabilities of actions according to the distribution. + Note that you must first call the ``proba_distribution()`` method. + + :param actions: (th.Tensor) + :return: (th.Tensor) + """ + log_prob = self.distribution.log_prob(actions) + return sum_independent_dims(log_prob) + + def entropy(self) -> th.Tensor: + return sum_independent_dims(self.distribution.entropy()) def sample(self) -> th.Tensor: # Reparametrization trick to pass gradients return self.distribution.rsample() - def entropy(self) -> th.Tensor: - return sum_independent_dims(self.distribution.entropy()) + def mode(self) -> th.Tensor: + return self.distribution.mean def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, @@ -168,22 +181,10 @@ class DiagGaussianDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - """ - Get the log probabilities of actions according to the distribution. - Note that you must call ``proba_distribution()`` method before. - - :param actions: (th.Tensor) - :return: (th.Tensor) - """ - log_prob = self.distribution.log_prob(actions) - return sum_independent_dims(log_prob) - class SquashedDiagGaussianDistribution(DiagGaussianDistribution): """ - Gaussian distribution with diagonal covariance matrix, - followed by a squashing function (tanh) to ensure bounds. + Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds. :param action_dim: (int) Dimension of the action space. :param epsilon: (float) small value to avoid NaN due to numerical imprecision. @@ -200,27 +201,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) return self - def mode(self) -> th.Tensor: - self.gaussian_actions = self.distribution.mean - # Squash the output - return th.tanh(self.gaussian_actions) - - def entropy(self) -> Optional[th.Tensor]: - # No analytical form, - # entropy needs to be estimated using -log_prob.mean() - return None - - def sample(self) -> th.Tensor: - # Reparametrization trick to pass gradients - self.gaussian_actions = self.distribution.rsample() - return th.tanh(self.gaussian_actions) - - def log_prob_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: - action = self.actions_from_params(mean_actions, log_std) - log_prob = self.log_prob(action, self.gaussian_actions) - return action, log_prob - def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: # Inverse tanh @@ -237,6 +217,27 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1) return log_prob + def entropy(self) -> Optional[th.Tensor]: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + + def sample(self) -> th.Tensor: + # Reparametrization trick to pass gradients + self.gaussian_actions = super().sample() + return th.tanh(self.gaussian_actions) + + def mode(self) -> th.Tensor: + self.gaussian_actions = super().mode() + # Squash the output + return th.tanh(self.gaussian_actions) + + def log_prob_from_params(self, mean_actions: th.Tensor, + log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + action = self.actions_from_params(mean_actions, log_std) + log_prob = self.log_prob(action, self.gaussian_actions) + return action, log_prob + class CategoricalDistribution(Distribution): """ @@ -267,14 +268,17 @@ class CategoricalDistribution(Distribution): self.distribution = Categorical(logits=action_logits) return self - def mode(self) -> th.Tensor: - return th.argmax(self.distribution.probs, dim=1) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy() def sample(self) -> th.Tensor: return self.distribution.sample() - def entropy(self) -> th.Tensor: - return self.distribution.entropy() + def mode(self) -> th.Tensor: + return th.argmax(self.distribution.probs, dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -287,9 +291,6 @@ class CategoricalDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - return self.distribution.log_prob(actions) - class MultiCategoricalDistribution(Distribution): """ @@ -321,14 +322,19 @@ class MultiCategoricalDistribution(Distribution): self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] return self - def mode(self) -> th.Tensor: - return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + # Extract each discrete action and compute log prob for their respective distributions + return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, + th.unbind(actions, dim=1))], dim=1).sum(dim=1) + + def entropy(self) -> th.Tensor: + return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) def sample(self) -> th.Tensor: return th.stack([dist.sample() for dist in self.distributions], dim=1) - def entropy(self) -> th.Tensor: - return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) + def mode(self) -> th.Tensor: + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -341,11 +347,6 @@ class MultiCategoricalDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - # Extract each discrete action and compute log prob for their respective distributions - return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, - th.unbind(actions, dim=1))], dim=1).sum(dim=1) - class BernoulliDistribution(Distribution): """ @@ -375,14 +376,17 @@ class BernoulliDistribution(Distribution): self.distribution = Bernoulli(logits=action_logits) return self - def mode(self) -> th.Tensor: - return th.round(self.distribution.probs) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions).sum(dim=1) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy().sum(dim=1) def sample(self) -> th.Tensor: return self.distribution.sample() - def entropy(self) -> th.Tensor: - return self.distribution.entropy().sum(dim=1) + def mode(self) -> th.Tensor: + return th.round(self.distribution.probs) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -395,9 +399,6 @@ class BernoulliDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - return self.distribution.log_prob(actions).sum(dim=1) - class StateDependentNoiseDistribution(Distribution): """ @@ -414,7 +415,7 @@ class StateDependentNoiseDistribution(Distribution): a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: (bool) Whether to squash the output using a tanh function, - this allows to ensure boundaries. + this ensures bounds are satisfied. :param learn_features: (bool) Whether to learn features for gSDE or not. This will enable gradients to be backpropagated through the features ``latent_sde`` in the code. @@ -529,6 +530,35 @@ class StateDependentNoiseDistribution(Distribution): self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon)) return self + def log_prob(self, actions: th.Tensor) -> th.Tensor: + if self.bijector is not None: + gaussian_actions = self.bijector.inverse(actions) + else: + gaussian_actions = actions + # log likelihood for a gaussian + log_prob = self.distribution.log_prob(gaussian_actions) + # Sum along action dim + log_prob = sum_independent_dims(log_prob) + + if self.bijector is not None: + # Squash correction (from original SAC implementation) + log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) + return log_prob + + def entropy(self) -> Optional[th.Tensor]: + if self.bijector is not None: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + return sum_independent_dims(self.distribution.entropy()) + + def sample(self) -> th.Tensor: + noise = self.get_noise(self._latent_sde) + actions = self.distribution.mean + noise + if self.bijector is not None: + return self.bijector.forward(actions) + return actions + def mode(self) -> th.Tensor: actions = self.distribution.mean if self.bijector is not None: @@ -547,20 +577,6 @@ class StateDependentNoiseDistribution(Distribution): noise = th.bmm(latent_sde, self.exploration_matrices) return noise.squeeze(1) - def sample(self) -> th.Tensor: - noise = self.get_noise(self._latent_sde) - actions = self.distribution.mean + noise - if self.bijector is not None: - return self.bijector.forward(actions) - return actions - - def entropy(self) -> Optional[th.Tensor]: - # No analytical form, - # entropy needs to be estimated using -log_prob.mean() - if self.bijector is not None: - return None - return sum_independent_dims(self.distribution.entropy()) - def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, @@ -576,21 +592,6 @@ class StateDependentNoiseDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - if self.bijector is not None: - gaussian_actions = self.bijector.inverse(actions) - else: - gaussian_actions = actions - # log likelihood for a gaussian - log_prob = self.distribution.log_prob(gaussian_actions) - # Sum along action dim - log_prob = sum_independent_dims(log_prob) - - if self.bijector is not None: - # Squash correction (from original SAC implementation) - log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) - return log_prob - class TanhBijector(object): """ @@ -653,9 +654,8 @@ def make_proba_distribution(action_space: gym.spaces.Space, if isinstance(action_space, spaces.Box): assert len(action_space.shape) == 1, "Error: the action space must be a vector" - if use_sde: - return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs) - return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) + cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution + cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete):