[Reinstate] Wishart distribution (#70377)

Summary:
Implements https://github.com/pytorch/pytorch/issues/68050.
Reopens the previously merged and reverted PR https://github.com/pytorch/pytorch/issues/68588; worked on with neerajprad.
cc neerajprad

Sorry for the confusion.

TODO:

- [x] Unit Test
- [x] Documentation
- [x] Change the constraint of the matrix parameters to `torch.distributions.constraints.symmetric` once it is reviewed and merged. Debug positive-definite constraints https://github.com/pytorch/pytorch/issues/68720
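
For reference, a minimal usage sketch of the new distribution (parameter names taken from the diff below; the specific numbers are only illustrative):

```python
import torch
from torch.distributions import Wishart

# 2x2 Wishart with df=4 and an identity scale matrix
m = Wishart(df=torch.tensor([4.0]), covariance_matrix=torch.eye(2))
x = m.sample()        # a positive-definite 2x2 matrix sample (batch shape (1,))
lp = m.log_prob(x)    # log density at the sampled matrix
print(m.mean)         # E[X] = df * covariance_matrix
```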

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70377

Reviewed By: mikaylagawarecki

Differential Revision: D33355132

Pulled By: neerajprad

fbshipit-source-id: e968c0d9a3061fb2855564b96074235e46a57b6c
This commit is contained in:
Juhyeong Kim 2021-12-30 11:40:18 -08:00 committed by Facebook GitHub Bot
parent 14d3d29b16
commit bc40fb5639
4 changed files with 506 additions and 5 deletions


@ -356,6 +356,15 @@ Probability distributions - torch.distributions
:undoc-members:
:show-inheritance:
:hidden:`Wishart`
~~~~~~~~~~~~~~~~~~~~~~~
.. currentmodule:: torch.distributions.wishart
.. autoclass:: Wishart
:members:
:undoc-members:
:show-inheritance:
`KL Divergence`
~~~~~~~~~~~~~~~~~~~~~~~


@ -61,7 +61,7 @@ from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
OneHotCategorical, OneHotCategoricalStraightThrough,
Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
StudentT, TransformedDistribution, Uniform,
VonMises, Weibull, constraints, kl_divergence)
VonMises, Weibull, Wishart, constraints, kl_divergence)
from torch.distributions.constraint_registry import transform_to
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
@ -473,6 +473,32 @@ EXAMPLES = [
'concentration': torch.randn(1).abs().requires_grad_()
}
]),
Example(Wishart, [
{
'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True),
'df': torch.tensor([4.], requires_grad=True),
},
{
'precision_matrix': torch.tensor([[2.0, 0.1, 0.0],
[0.1, 0.25, 0.0],
[0.0, 0.0, 0.3]], requires_grad=True),
'df': torch.tensor([2.5, 3], requires_grad=True),
},
{
'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]],
[[2.0, 0.0], [0.3, 0.25]],
[[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True),
'df': torch.tensor([5., 3.5, 2], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
'df': torch.tensor([2.0]),
},
{
'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
'df': 2.0,
},
]),
Example(MixtureSameFamily, [
{
'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
@ -740,6 +766,20 @@ BAD_EXAMPLES = [
'concentration': torch.tensor([-1.0], requires_grad=True)
}
]),
Example(Wishart, [
{
'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True),
'df': torch.tensor([1.5], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
'df': torch.tensor([3.], requires_grad=True),
},
{
'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
'df': 3.,
},
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([-0.5], requires_grad=True)},
@ -792,10 +832,10 @@ class TestDistributions(TestCase):
ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
if multivariate:
# Project onto a random axis.
axis = np.random.normal(size=torch_samples.shape[-1])
axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
axis /= np.linalg.norm(axis)
torch_samples = np.dot(torch_samples, axis)
ref_samples = np.dot(ref_samples, axis)
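# Project via elementwise multiply and sum so that matrix-valued events (e.g. Wishart samples) are handled too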
torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
if circular:
samples = [(np.cos(x), v) for (x, v) in samples]
@ -2168,6 +2208,148 @@ class TestDistributions(TestCase):
empirical_var = samples.var(0)
self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)
# We apply the same tests used for the MultivariateNormal distribution to the Wishart distribution
def test_wishart_shape(self):
df = (torch.rand(5, requires_grad=True) + 1) * 10
df_no_batch = (torch.rand([], requires_grad=True) + 1) * 10
df_multi_batch = (torch.rand(6, 5, requires_grad=True) + 1) * 10
# construct PSD covariance
tmp = torch.randn(3, 10)
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
prec = cov.inverse().requires_grad_()
scale_tril = torch.linalg.cholesky(cov).requires_grad_()
# construct batch of PSD covariances
tmp = torch.randn(6, 5, 3, 10)
cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
prec_batched = cov_batched.inverse()
scale_tril_batched = torch.linalg.cholesky(cov_batched)
# ensure that sample, batch, event shapes all handled correctly
self.assertEqual(Wishart(df, cov).sample().size(), (5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov).sample().size(), (3, 3))
self.assertEqual(Wishart(df_multi_batch, cov).sample().size(), (6, 5, 3, 3))
self.assertEqual(Wishart(df, cov).sample((2,)).size(), (2, 5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov).sample((2,)).size(), (2, 3, 3))
self.assertEqual(Wishart(df_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3, 3))
self.assertEqual(Wishart(df, cov).sample((2, 7)).size(), (2, 7, 5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov).sample((2, 7)).size(), (2, 7, 3, 3))
self.assertEqual(Wishart(df_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df_no_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df_multi_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df, precision_matrix=prec).sample((2, 7)).size(), (2, 7, 5, 3, 3))
self.assertEqual(Wishart(df, precision_matrix=prec_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
self.assertEqual(Wishart(df, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3, 3))
self.assertEqual(Wishart(df, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
# Check gradients of log_prob; adapted from the corresponding multivariate_normal tests
def wishart_log_prob_gradcheck(df=None, covariance=None, precision=None, scale_tril=None):
wishart_samples = Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()
def gradcheck_func(samples, nu, sigma, prec, scale_tril):
if sigma is not None:
sigma = 0.5 * (sigma + sigma.mT) # Ensure symmetry of covariance
if prec is not None:
prec = 0.5 * (prec + prec.mT) # Ensure symmetry of precision
if scale_tril is not None:
scale_tril = scale_tril.tril()
return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)
gradcheck(gradcheck_func, (wishart_samples, df, covariance, precision, scale_tril), raise_exception=True)
wishart_log_prob_gradcheck(df, cov)
wishart_log_prob_gradcheck(df_multi_batch, cov)
wishart_log_prob_gradcheck(df_multi_batch, cov_batched)
wishart_log_prob_gradcheck(df, None, prec)
wishart_log_prob_gradcheck(df_no_batch, None, prec_batched)
wishart_log_prob_gradcheck(df, None, None, scale_tril)
wishart_log_prob_gradcheck(df_no_batch, None, None, scale_tril_batched)
def test_wishart_stable_with_precision_matrix(self):
x = torch.randn(10)
P = torch.exp(-(x - x.unsqueeze(-1)) ** 2) # RBF kernel
Wishart(torch.tensor(10), precision_matrix=P)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_wishart_log_prob(self):
df = (torch.rand([], requires_grad=True) + 1) * 10
tmp = torch.randn(3, 10)
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
prec = cov.inverse().requires_grad_()
scale_tril = torch.linalg.cholesky(cov).requires_grad_()
# check that logprob values match scipy logpdf,
# and that covariance and scale_tril parameters are equivalent
dist1 = Wishart(df, cov)
dist2 = Wishart(df, precision_matrix=prec)
dist3 = Wishart(df, scale_tril=scale_tril)
ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())
x = dist1.sample((10,))
expected = ref_dist.logpdf(x.transpose(0, 2).numpy())
self.assertEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
self.assertEqual(0.0, np.mean((dist2.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
self.assertEqual(0.0, np.mean((dist3.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
# Double-check that batched versions behave the same as unbatched
df = (torch.rand(5, requires_grad=True) + 1) * 3
tmp = torch.randn(5, 3, 10)
cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
dist_batched = Wishart(df, cov)
dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]
x = dist_batched.sample((10,))
batched_prob = dist_batched.log_prob(x)
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()
self.assertEqual(batched_prob.shape, unbatched_prob.shape)
self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_wishart_sample(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
df = (torch.rand([], requires_grad=True) + 1) * 3
tmp = torch.randn(3, 10)
cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
prec = cov.inverse().requires_grad_()
scale_tril = torch.linalg.cholesky(cov).requires_grad_()
self._check_sampler_sampler(Wishart(df, cov),
scipy.stats.wishart(df.item(), cov.detach().numpy()),
'Wishart(df={}, covariance_matrix={})'.format(df, cov),
multivariate=True)
self._check_sampler_sampler(Wishart(df, precision_matrix=prec),
scipy.stats.wishart(df.item(), cov.detach().numpy()),
'Wishart(df={}, precision_matrix={})'.format(df, prec),
multivariate=True)
self._check_sampler_sampler(Wishart(df, scale_tril=scale_tril),
scipy.stats.wishart(df.item(), cov.detach().numpy()),
'Wishart(df={}, scale_tril={})'.format(df, scale_tril),
multivariate=True)
def test_wishart_properties(self):
df = (torch.rand([]) + 1) * 5
scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
m = Wishart(df=df, scale_tril=scale_tril)
self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0]))
self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
def test_wishart_moments(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
df = (torch.rand([]) + 1) * 3
scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(3, 3))
d = Wishart(df=df, scale_tril=scale_tril)
samples = d.rsample((100000,))
empirical_mean = samples.mean(0)
self.assertEqual(d.mean, empirical_mean, atol=5, rtol=0)
empirical_var = samples.var(0)
self.assertEqual(d.variance, empirical_var, atol=5, rtol=0)
def test_exponential(self):
rate = torch.randn(5, 5).abs().requires_grad_()
rate_1d = torch.randn(1).abs().requires_grad_()
@ -3487,6 +3669,23 @@ class TestDistributionShapes(TestCase):
self.assertEqual(weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
def test_wishart_shape_scalar_params(self):
wishart = Wishart(torch.tensor(1), torch.tensor([[1.]]))
self.assertEqual(wishart._batch_shape, torch.Size())
self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
self.assertEqual(wishart.sample().size(), torch.Size((1, 1)))
self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 1, 1)))
self.assertRaises(ValueError, wishart.log_prob, self.scalar_sample)
def test_wishart_shape_tensor_params(self):
wishart = Wishart(torch.tensor([1., 1.]), torch.tensor([[[1.]], [[1.]]]))
self.assertEqual(wishart._batch_shape, torch.Size((2,)))
self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
self.assertEqual(wishart.sample().size(), torch.Size((2, 1, 1)))
self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 2, 1, 1)))
self.assertRaises(ValueError, wishart.log_prob, self.tensor_sample_2)
self.assertEqual(wishart.log_prob(torch.ones(2, 1, 1)).size(), torch.Size((2,)))
def test_normal_shape_scalar_params(self):
normal = Normal(0, 1)
self.assertEqual(normal._batch_shape, torch.Size())
@ -4305,6 +4504,8 @@ class TestAgainstScipy(TestCase):
positive_var2 = torch.randn(20).exp()
random_var = torch.randn(20)
simplex_tensor = softmax(torch.randn(20), dim=-1)
cov_tensor = torch.randn(20, 20)
cov_tensor = cov_tensor @ cov_tensor.mT
self.distribution_pairs = [
(
Bernoulli(simplex_tensor),
@ -4375,6 +4576,10 @@ class TestAgainstScipy(TestCase):
MultivariateNormal(random_var, torch.diag(positive_var2)),
scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2))
),
(
MultivariateNormal(random_var, cov_tensor),
scipy.stats.multivariate_normal(random_var, cov_tensor)
),
(
Normal(random_var, positive_var2),
scipy.stats.norm(random_var, positive_var2)
@ -4406,7 +4611,11 @@ class TestAgainstScipy(TestCase):
(
Weibull(positive_var[0], positive_var2[0]), # scipy var for Weibull only supports scalars
scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
)
),
(
Wishart(20 + positive_var[0], cov_tensor), # scipy var for Wishart only supports scalars
scipy.stats.wishart(20 + positive_var[0].item(), cov_tensor),
),
]
def test_mean(self):


@ -113,6 +113,7 @@ from .transforms import * # noqa: F403
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart
from . import transforms
__all__ = [
@ -155,6 +156,7 @@ __all__ = [
'Uniform',
'VonMises',
'Weibull',
'Wishart',
'TransformedDistribution',
'biject_to',
'kl_divergence',


@ -0,0 +1,281 @@
import math
import warnings
from numbers import Number
from typing import Union
import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import lazy_property
from torch.distributions.multivariate_normal import _precision_to_scale_tril
_log_2 = math.log(2)
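# Multivariate digamma function: psi_p(x) = sum_{i=1}^{p} digamma(x + (1 - i) / 2),
# the derivative of the log multivariate gamma function log(Gamma_p(x)).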
def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
return torch.digamma(
x.unsqueeze(-1)
- torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
).sum(-1)
class Wishart(ExponentialFamily):
r"""
Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`, and a degrees-of-freedom
parameter :attr:`df`.
Example:
>>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
>>> m.sample()  # Wishart distributed with mean=`df * I` and
>>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
Args:
covariance_matrix (Tensor): positive-definite covariance matrix
precision_matrix (Tensor): positive-definite precision matrix
scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
df (float or Tensor): real-valued parameter larger than the (dimension of the square matrix) - 1
Note:
Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
:attr:`scale_tril` can be specified.
Using :attr:`scale_tril` will be more efficient: all computations internally
are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
:attr:`precision_matrix` is passed instead, it is only used to compute
the corresponding lower triangular matrices using a Cholesky decomposition.
:class:`~torch.distributions.LKJCholesky` is a restricted Wishart distribution [1].
**References**
[1] `On equivalence of the LKJ distribution and the restricted Wishart distribution`,
Zhenxun Wang, Yunan Wu, Haitao Chu.
"""
arg_constraints = {
'covariance_matrix': constraints.positive_definite,
'precision_matrix': constraints.positive_definite,
'scale_tril': constraints.lower_cholesky,
'df': constraints.greater_than(0),
}
support = constraints.positive_definite
has_rsample = True
def __init__(self,
df: Union[torch.Tensor, Number],
covariance_matrix: torch.Tensor = None,
precision_matrix: torch.Tensor = None,
scale_tril: torch.Tensor = None,
validate_args=None):
assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
"Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)
if param.dim() < 2:
raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")
if isinstance(df, Number):
batch_shape = torch.Size(param.shape[:-2])
self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
else:
batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
self.df = df.expand(batch_shape)
event_shape = param.shape[-2:]
if self.df.le(event_shape[-1] - 1).any():
raise ValueError(f"Value of df={df} expected to be greater than ndim={event_shape[-1]-1}.")
if scale_tril is not None:
self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None:
self.precision_matrix = param.expand(batch_shape + (-1, -1))
# Set the df constraint per instance instead of mutating the class-level arg_constraints dict
self.arg_constraints = dict(self.arg_constraints, df=constraints.greater_than(event_shape[-1] - 1))
if self.df.lt(event_shape[-1]).any():
warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.")
super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]
if scale_tril is not None:
self._unbroadcasted_scale_tril = scale_tril
elif covariance_matrix is not None:
self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
else: # precision_matrix is not None
self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
# Chi2 distribution is needed for Bartlett decomposition sampling
self._dist_chi2 = torch.distributions.chi2.Chi2(
df=(
self.df.unsqueeze(-1)
- torch.arange(
self._event_shape[-1],
dtype=self._unbroadcasted_scale_tril.dtype,
device=self._unbroadcasted_scale_tril.device,
).expand(batch_shape + (-1,))
)
)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Wishart, _instance)
batch_shape = torch.Size(batch_shape)
cov_shape = batch_shape + self.event_shape
df_shape = batch_shape
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
new.df = self.df.expand(df_shape)
new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]
if 'covariance_matrix' in self.__dict__:
new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
if 'scale_tril' in self.__dict__:
new.scale_tril = self.scale_tril.expand(cov_shape)
if 'precision_matrix' in self.__dict__:
new.precision_matrix = self.precision_matrix.expand(cov_shape)
# Chi2 distribution is needed for Bartlett decomposition sampling
new._dist_chi2 = torch.distributions.chi2.Chi2(
df=(
new.df.unsqueeze(-1)
- torch.arange(
self.event_shape[-1],
dtype=new._unbroadcasted_scale_tril.dtype,
device=new._unbroadcasted_scale_tril.device,
).expand(batch_shape + (-1,))
)
)
super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new
@lazy_property
def scale_tril(self):
return self._unbroadcasted_scale_tril.expand(
self._batch_shape + self._event_shape)
@lazy_property
def covariance_matrix(self):
return (
self._unbroadcasted_scale_tril @ self._unbroadcasted_scale_tril.transpose(-2, -1)
).expand(self._batch_shape + self._event_shape)
@lazy_property
def precision_matrix(self):
identity = torch.eye(
self._event_shape[-1],
device=self._unbroadcasted_scale_tril.device,
dtype=self._unbroadcasted_scale_tril.dtype,
)
return torch.cholesky_solve(
identity, self._unbroadcasted_scale_tril
).expand(self._batch_shape + self._event_shape)
@property
def mean(self):
return self.df.view(self._batch_shape + (1, 1,)) * self.covariance_matrix
@property
def variance(self):
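# Var(X_ij) = df * (V_ij^2 + V_ii * V_jj), with V the scale (covariance) matrix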
V = self.covariance_matrix # has shape (batch_shape x event_shape)
diag_V = V.diagonal(dim1=-2, dim2=-1)
return self.df.view(self._batch_shape + (1, 1,)) * (V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V))
def _bartlett_sampling(self, sample_shape=torch.Size()):
p = self._event_shape[-1] # has singleton shape
# Implemented Sampling using Bartlett decomposition
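# Bartlett decomposition: build a lower-triangular matrix A with
#   A_ii = sqrt(c_i), c_i ~ Chi2(df - i + 1), and A_ij ~ Normal(0, 1) for i > j;
# then (L A)(L A)^T ~ Wishart(df, Sigma), where Sigma = L L^T.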
noise = self._dist_chi2.rsample(sample_shape).sqrt().diag_embed(dim1=-2, dim2=-1)
i, j = torch.tril_indices(p, p, offset=-1)
noise[..., i, j] = torch.randn(
torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
dtype=noise.dtype,
device=noise.device,
)
chol = self._unbroadcasted_scale_tril @ noise
return chol @ chol.transpose(-2, -1)
def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
r"""
.. warning::
In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples.
Several retries to correct singular samples are performed by default, but the method may still end up returning
singular matrix samples. Singular samples may yield `-inf` values in `.log_prob()`.
In those cases, the user should validate the samples and either fix the value of `df`
or adjust the `max_try_correction` argument of `.rsample` accordingly.
"""
if max_try_correction is None:
max_try_correction = 3 if torch._C._get_tracing_state() else 10
sample_shape = torch.Size(sample_shape)
sample = self._bartlett_sampling(sample_shape)
# The block below improves numerical stability temporarily and should be removed in the future
is_singular = ~self.support.check(sample)
if self._batch_shape:
is_singular = is_singular.amax(self._batch_dims)
if torch._C._get_tracing_state():
# Less optimized version for JIT
for _ in range(max_try_correction):
sample_new = self._bartlett_sampling(sample_shape)
sample = torch.where(is_singular, sample_new, sample)
is_singular = ~self.support.check(sample)
if self._batch_shape:
is_singular = is_singular.amax(self._batch_dims)
else:
# More optimized version with data-dependent control flow.
if is_singular.any():
warnings.warn("Singular sample detected.")
for _ in range(max_try_correction):
sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
sample[is_singular] = sample_new
is_singular_new = ~self.support.check(sample_new)
if self._batch_shape:
is_singular_new = is_singular_new.amax(self._batch_dims)
is_singular[is_singular.clone()] = is_singular_new
if not is_singular.any():
break
return sample
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
nu = self.df # has shape (batch_shape)
p = self._event_shape[-1] # has singleton shape
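# Wishart log density:
#   log p(X) = (nu - p - 1)/2 * log|X| - tr(Sigma^{-1} X)/2
#              - nu*p/2 * log(2) - nu/2 * log|Sigma| - log(Gamma_p(nu/2)),
# with log|Sigma| and the trace term evaluated via the Cholesky factor of Sigma.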
return (
- nu * p * _log_2 / 2
- nu * self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
- torch.mvlgamma(nu / 2, p=p)
+ (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
- torch.cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2
)
def entropy(self):
nu = self.df # has shape (batch_shape)
p = self._event_shape[-1] # has singleton shape
V = self.covariance_matrix # has shape (batch_shape x event_shape)
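# Wishart entropy:
#   H = (p + 1)/2 * log|Sigma| + p*(p + 1)/2 * log(2) + log(Gamma_p(nu/2))
#       - (nu - p - 1)/2 * psi_p(nu/2) + nu*p/2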
return (
(p + 1) * self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+ p * (p + 1) * _log_2 / 2
+ torch.mvlgamma(nu / 2, p=p)
- (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
+ nu * p / 2
)
@property
def _natural_params(self):
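# Exponential-family form with sufficient statistics (log|X|, X):
# natural parameters are (df/2, -precision_matrix/2) and the log normalizer is
# A(x, y) = x * (p*log(2) - log|det(-2y)|) + log(Gamma_p(x)).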
return (
0.5 * self.df,
- 0.5 * self.precision_matrix,
)
def _log_normalizer(self, x, y):
p = y.shape[-1]
return x * (- torch.linalg.slogdet(-2 * y).logabsdet + _log_2 * p) + torch.mvlgamma(x, p=p)