[Reinstate] Wishart distribution (#70377)
Summary:
Implements https://github.com/pytorch/pytorch/issues/68050. Reopens the merged-and-reverted PR https://github.com/pytorch/pytorch/issues/68588; worked with neerajprad.

cc neerajprad. Sorry for the confusion.

TODO:
- [x] Unit tests
- [x] Documentation
- [x] Change the constraint on matrix arguments to 'torch.distributions.constraints.symmetric' once that constraint is reviewed and merged: Debug positive definite constraints https://github.com/pytorch/pytorch/issues/68720

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70377
Reviewed By: mikaylagawarecki
Differential Revision: D33355132
Pulled By: neerajprad
fbshipit-source-id: e968c0d9a3061fb2855564b96074235e46a57b6c
This commit is contained in:
parent 14d3d29b16
commit bc40fb5639

4 changed files with 506 additions and 5 deletions
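Before the per-file hunks, a minimal sketch of the API this commit adds. The constructor arguments and methods are taken from the diff below; the concrete numbers are illustrative only:

```python
import torch
from torch.distributions import Wishart

# Parameterize by degrees of freedom plus exactly one of
# covariance_matrix, precision_matrix, or scale_tril.
m = Wishart(df=torch.tensor(5.0), covariance_matrix=torch.eye(3))
X = m.rsample((2,))              # two 3x3 positive-definite samples
logp = m.log_prob(X)             # log density of each sample
print(m.mean.shape, logp.shape)  # torch.Size([3, 3]) torch.Size([2])
```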
docs/source/distributions.rst
@@ -356,6 +356,15 @@ Probability distributions - torch.distributions
     :undoc-members:
     :show-inheritance:
 
+:hidden:`Wishart`
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torch.distributions.wishart
+.. autoclass:: Wishart
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 `KL Divergence`
 ~~~~~~~~~~~~~~~~~~~~~~~
 
test/distributions/test_distributions.py
@@ -61,7 +61,7 @@ from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
                                  OneHotCategorical, OneHotCategoricalStraightThrough,
                                  Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                  StudentT, TransformedDistribution, Uniform,
-                                 VonMises, Weibull, constraints, kl_divergence)
+                                 VonMises, Weibull, Wishart, constraints, kl_divergence)
 from torch.distributions.constraint_registry import transform_to
 from torch.distributions.constraints import Constraint, is_dependent
 from torch.distributions.dirichlet import _Dirichlet_backward
@@ -473,6 +473,32 @@ EXAMPLES = [
         'concentration': torch.randn(1).abs().requires_grad_()
     }
 ]),
+Example(Wishart, [
+    {
+        'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True),
+        'df': torch.tensor([4.], requires_grad=True),
+    },
+    {
+        'precision_matrix': torch.tensor([[2.0, 0.1, 0.0],
+                                          [0.1, 0.25, 0.0],
+                                          [0.0, 0.0, 0.3]], requires_grad=True),
+        'df': torch.tensor([2.5, 3], requires_grad=True),
+    },
+    {
+        'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]],
+                                    [[2.0, 0.0], [0.3, 0.25]],
+                                    [[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True),
+        'df': torch.tensor([5., 3.5, 2], requires_grad=True),
+    },
+    {
+        'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
+        'df': torch.tensor([2.0]),
+    },
+    {
+        'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
+        'df': 2.0,
+    },
+]),
 Example(MixtureSameFamily, [
     {
         'mixture_distribution': Categorical(torch.rand(5, requires_grad=True)),
@@ -740,6 +766,20 @@ BAD_EXAMPLES = [
         'concentration': torch.tensor([-1.0], requires_grad=True)
     }
 ]),
+Example(Wishart, [
+    {
+        'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True),
+        'df': torch.tensor([1.5], requires_grad=True),
+    },
+    {
+        'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
+        'df': torch.tensor([3.], requires_grad=True),
+    },
+    {
+        'covariance_matrix': torch.tensor([[1.0, 1.0], [1.0, -2.0]], requires_grad=True),
+        'df': 3.,
+    },
+]),
 Example(ContinuousBernoulli, [
     {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
     {'probs': torch.tensor([-0.5], requires_grad=True)},
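The BAD_EXAMPLES entries above feed tests that expect parameter validation to reject these distributions. A hedged sketch of what that rejection looks like (assuming validation is enabled; the matrix is the first bad example above, which has a negative eigenvalue):

```python
import torch
from torch.distributions import Wishart

try:
    # [[1., 0.], [0., -2.]] is not positive definite, so it violates
    # the positive_definite constraint on covariance_matrix.
    Wishart(df=torch.tensor([3.]),
            covariance_matrix=torch.tensor([[1.0, 0.0], [0.0, -2.0]]),
            validate_args=True)
except ValueError as e:
    print("rejected:", e)
```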
@@ -792,10 +832,10 @@ class TestDistributions(TestCase):
         ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
         if multivariate:
             # Project onto a random axis.
-            axis = np.random.normal(size=torch_samples.shape[-1])
+            axis = np.random.normal(size=(1,) + torch_samples.shape[1:])
             axis /= np.linalg.norm(axis)
-            torch_samples = np.dot(torch_samples, axis)
-            ref_samples = np.dot(ref_samples, axis)
+            torch_samples = (axis * torch_samples).reshape(num_samples, -1).sum(-1)
+            ref_samples = (axis * ref_samples).reshape(num_samples, -1).sum(-1)
         samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
         if circular:
             samples = [(np.cos(x), v) for (x, v) in samples]
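The change above generalizes the random-projection trick: `np.dot` only handles vector-valued samples, while the elementwise multiply/reshape/sum form also reduces matrix-valued samples (such as Wishart draws) to scalars for the one-dimensional statistical test. A standalone sketch of the idea, with illustrative shapes:

```python
import numpy as np

num_samples, p = 1000, 3
samples = np.random.randn(num_samples, p, p)  # matrix-valued samples

# One random direction with the sample's event shape, normalized,
# then each sample is reduced to a scalar via an inner product.
axis = np.random.normal(size=(1,) + samples.shape[1:])
axis /= np.linalg.norm(axis)
projected = (axis * samples).reshape(num_samples, -1).sum(-1)
print(projected.shape)  # (1000,)
```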
@@ -2168,6 +2208,148 @@ class TestDistributions(TestCase):
         empirical_var = samples.var(0)
         self.assertEqual(d.variance, empirical_var, atol=0.05, rtol=0)
 
+    # We apply the same tests used for the MultivariateNormal distribution to Wishart
+    def test_wishart_shape(self):
+        df = (torch.rand(5, requires_grad=True) + 1) * 10
+        df_no_batch = (torch.rand([], requires_grad=True) + 1) * 10
+        df_multi_batch = (torch.rand(6, 5, requires_grad=True) + 1) * 10
+
+        # construct PSD covariance
+        tmp = torch.randn(3, 10)
+        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
+        prec = cov.inverse().requires_grad_()
+        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
+
+        # construct batch of PSD covariances
+        tmp = torch.randn(6, 5, 3, 10)
+        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
+        prec_batched = cov_batched.inverse()
+        scale_tril_batched = torch.linalg.cholesky(cov_batched)
+
+        # ensure that sample, batch, event shapes are all handled correctly
+        self.assertEqual(Wishart(df, cov).sample().size(), (5, 3, 3))
+        self.assertEqual(Wishart(df_no_batch, cov).sample().size(), (3, 3))
+        self.assertEqual(Wishart(df_multi_batch, cov).sample().size(), (6, 5, 3, 3))
+        self.assertEqual(Wishart(df, cov).sample((2,)).size(), (2, 5, 3, 3))
+        self.assertEqual(Wishart(df_no_batch, cov).sample((2,)).size(), (2, 3, 3))
+        self.assertEqual(Wishart(df_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3, 3))
+        self.assertEqual(Wishart(df, cov).sample((2, 7)).size(), (2, 7, 5, 3, 3))
+        self.assertEqual(Wishart(df_no_batch, cov).sample((2, 7)).size(), (2, 7, 3, 3))
+        self.assertEqual(Wishart(df_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
+        self.assertEqual(Wishart(df, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
+        self.assertEqual(Wishart(df_no_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
+        self.assertEqual(Wishart(df_multi_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
+        self.assertEqual(Wishart(df, precision_matrix=prec).sample((2, 7)).size(), (2, 7, 5, 3, 3))
+        self.assertEqual(Wishart(df, precision_matrix=prec_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
+        self.assertEqual(Wishart(df, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3, 3))
+        self.assertEqual(Wishart(df, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3, 3))
+
+        # check gradients
+        # adapted from the corresponding multivariate_normal tests
+        def wishart_log_prob_gradcheck(df=None, covariance=None, precision=None, scale_tril=None):
+            wishart_samples = Wishart(df, covariance, precision, scale_tril).sample().requires_grad_()
+
+            def gradcheck_func(samples, nu, sigma, prec, scale_tril):
+                if sigma is not None:
+                    sigma = 0.5 * (sigma + sigma.mT)  # Ensure symmetry of covariance
+                if prec is not None:
+                    prec = 0.5 * (prec + prec.mT)  # Ensure symmetry of precision
+                if scale_tril is not None:
+                    scale_tril = scale_tril.tril()
+                return Wishart(nu, sigma, prec, scale_tril).log_prob(samples)
+            gradcheck(gradcheck_func, (wishart_samples, df, covariance, precision, scale_tril), raise_exception=True)
+
+        wishart_log_prob_gradcheck(df, cov)
+        wishart_log_prob_gradcheck(df_multi_batch, cov)
+        wishart_log_prob_gradcheck(df_multi_batch, cov_batched)
+        wishart_log_prob_gradcheck(df, None, prec)
+        wishart_log_prob_gradcheck(df_no_batch, None, prec_batched)
+        wishart_log_prob_gradcheck(df, None, None, scale_tril)
+        wishart_log_prob_gradcheck(df_no_batch, None, None, scale_tril_batched)
+
+    def test_wishart_stable_with_precision_matrix(self):
+        x = torch.randn(10)
+        P = torch.exp(-(x - x.unsqueeze(-1)) ** 2)  # RBF kernel
+        Wishart(torch.tensor(10), precision_matrix=P)
+
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_wishart_log_prob(self):
+        df = (torch.rand([], requires_grad=True) + 1) * 10
+        tmp = torch.randn(3, 10)
+        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
+        prec = cov.inverse().requires_grad_()
+        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
+
+        # check that log_prob values match scipy's logpdf, and that the
+        # covariance, precision, and scale_tril parameterizations are equivalent
+        dist1 = Wishart(df, cov)
+        dist2 = Wishart(df, precision_matrix=prec)
+        dist3 = Wishart(df, scale_tril=scale_tril)
+        ref_dist = scipy.stats.wishart(df.item(), cov.detach().numpy())
+
+        x = dist1.sample((10,))
+        expected = ref_dist.logpdf(x.transpose(0, 2).numpy())
+
+        self.assertEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
+        self.assertEqual(0.0, np.mean((dist2.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
+        self.assertEqual(0.0, np.mean((dist3.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
+
+        # Double-check that batched versions behave the same as unbatched
+        df = (torch.rand(5, requires_grad=True) + 1) * 3
+        tmp = torch.randn(5, 3, 10)
+        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
+
+        dist_batched = Wishart(df, cov)
+        dist_unbatched = [Wishart(df[i], cov[i]) for i in range(df.size(0))]
+
+        x = dist_batched.sample((10,))
+        batched_prob = dist_batched.log_prob(x)
+        unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()
+
+        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
+        self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
+
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_wishart_sample(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        df = (torch.rand([], requires_grad=True) + 1) * 3
+        tmp = torch.randn(3, 10)
+        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
+        prec = cov.inverse().requires_grad_()
+        scale_tril = torch.linalg.cholesky(cov).requires_grad_()
+
+        self._check_sampler_sampler(Wishart(df, cov),
+                                    scipy.stats.wishart(df.item(), cov.detach().numpy()),
+                                    'Wishart(df={}, covariance_matrix={})'.format(df, cov),
+                                    multivariate=True)
+        self._check_sampler_sampler(Wishart(df, precision_matrix=prec),
+                                    scipy.stats.wishart(df.item(), cov.detach().numpy()),
+                                    'Wishart(df={}, precision_matrix={})'.format(df, prec),
+                                    multivariate=True)
+        self._check_sampler_sampler(Wishart(df, scale_tril=scale_tril),
+                                    scipy.stats.wishart(df.item(), cov.detach().numpy()),
+                                    'Wishart(df={}, scale_tril={})'.format(df, scale_tril),
+                                    multivariate=True)
+
+    def test_wishart_properties(self):
+        df = (torch.rand([]) + 1) * 5
+        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
+        m = Wishart(df=df, scale_tril=scale_tril)
+        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
+        self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0]))
+        self.assertEqual(m.scale_tril, torch.linalg.cholesky(m.covariance_matrix))
+
+    def test_wishart_moments(self):
+        set_rng_seed(0)  # see Note [Randomized statistical tests]
+        df = (torch.rand([]) + 1) * 3
+        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(3, 3))
+        d = Wishart(df=df, scale_tril=scale_tril)
+        samples = d.rsample((100000,))
+        empirical_mean = samples.mean(0)
+        self.assertEqual(d.mean, empirical_mean, atol=5, rtol=0)
+        empirical_var = samples.var(0)
+        self.assertEqual(d.variance, empirical_var, atol=5, rtol=0)
+
     def test_exponential(self):
         rate = torch.randn(5, 5).abs().requires_grad_()
         rate_1d = torch.randn(1).abs().requires_grad_()
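For reference, `test_wishart_moments` above compares the empirical moments against the closed-form Wishart moments that the `mean` and `variance` properties implement (standard facts for a Wishart with scale matrix Σ and ν degrees of freedom):

```latex
\mathbb{E}[X] = \nu \Sigma, \qquad
\operatorname{Var}(X_{ij}) = \nu \left( \Sigma_{ij}^2 + \Sigma_{ii}\,\Sigma_{jj} \right)
```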
@@ -3487,6 +3669,23 @@ class TestDistributionShapes(TestCase):
         self.assertEqual(weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
         self.assertEqual(weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
 
+    def test_wishart_shape_scalar_params(self):
+        wishart = Wishart(torch.tensor(1), torch.tensor([[1.]]))
+        self.assertEqual(wishart._batch_shape, torch.Size())
+        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
+        self.assertEqual(wishart.sample().size(), torch.Size((1, 1)))
+        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 1, 1)))
+        self.assertRaises(ValueError, wishart.log_prob, self.scalar_sample)
+
+    def test_wishart_shape_tensor_params(self):
+        wishart = Wishart(torch.tensor([1., 1.]), torch.tensor([[[1.]], [[1.]]]))
+        self.assertEqual(wishart._batch_shape, torch.Size((2,)))
+        self.assertEqual(wishart._event_shape, torch.Size((1, 1)))
+        self.assertEqual(wishart.sample().size(), torch.Size((2, 1, 1)))
+        self.assertEqual(wishart.sample((3, 2)).size(), torch.Size((3, 2, 2, 1, 1)))
+        self.assertRaises(ValueError, wishart.log_prob, self.tensor_sample_2)
+        self.assertEqual(wishart.log_prob(torch.ones(2, 1, 1)).size(), torch.Size((2,)))
+
     def test_normal_shape_scalar_params(self):
         normal = Normal(0, 1)
         self.assertEqual(normal._batch_shape, torch.Size())
@@ -4305,6 +4504,8 @@ class TestAgainstScipy(TestCase):
         positive_var2 = torch.randn(20).exp()
         random_var = torch.randn(20)
         simplex_tensor = softmax(torch.randn(20), dim=-1)
+        cov_tensor = torch.randn(20, 20)
+        cov_tensor = cov_tensor @ cov_tensor.mT
         self.distribution_pairs = [
             (
                 Bernoulli(simplex_tensor),
@@ -4375,6 +4576,10 @@ class TestAgainstScipy(TestCase):
                 MultivariateNormal(random_var, torch.diag(positive_var2)),
                 scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2))
             ),
+            (
+                MultivariateNormal(random_var, cov_tensor),
+                scipy.stats.multivariate_normal(random_var, cov_tensor)
+            ),
             (
                 Normal(random_var, positive_var2),
                 scipy.stats.norm(random_var, positive_var2)
@@ -4406,7 +4611,11 @@ class TestAgainstScipy(TestCase):
             (
                 Weibull(positive_var[0], positive_var2[0]),  # scipy var for Weibull only supports scalars
                 scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
-            )
+            ),
+            (
+                Wishart(20 + positive_var[0], cov_tensor),  # scipy var for Wishart only supports scalars
+                scipy.stats.wishart(20 + positive_var[0].item(), cov_tensor),
+            ),
         ]
 
     def test_mean(self):
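These pairs feed the TestAgainstScipy moment and log-probability checks. A self-contained sketch of the same cross-check for the Wishart pair, with illustrative sizes (note that `scipy.stats.wishart` only accepts a scalar `df`, hence the `.item()` call in the hunk above):

```python
import torch
import scipy.stats

df, p = 25.0, 4
tmp = torch.randn(p, 2 * p)
cov = tmp @ tmp.mT / tmp.size(-1)  # random PSD scale matrix

d_torch = torch.distributions.Wishart(torch.tensor(df), covariance_matrix=cov)
d_scipy = scipy.stats.wishart(df, cov.numpy())

x = d_torch.sample()
# The two log densities should roughly agree for any valid sample.
print(d_torch.log_prob(x).item(), d_scipy.logpdf(x.numpy()))
```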
torch/distributions/__init__.py
@@ -113,6 +113,7 @@ from .transforms import * # noqa: F403
 from .uniform import Uniform
 from .von_mises import VonMises
 from .weibull import Weibull
+from .wishart import Wishart
 from . import transforms
 
 __all__ = [
@@ -155,6 +156,7 @@ __all__ = [
     'Uniform',
     'VonMises',
     'Weibull',
+    'Wishart',
     'TransformedDistribution',
     'biject_to',
     'kl_divergence',
torch/distributions/wishart.py (new file, 281 lines)
@@ -0,0 +1,281 @@
import math
import warnings
from numbers import Number
from typing import Union

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import lazy_property
from torch.distributions.multivariate_normal import _precision_to_scale_tril


_log_2 = math.log(2)


def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
    return torch.digamma(
        x.unsqueeze(-1)
        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
    ).sum(-1)


class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`.

    Example:
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
                        # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
        df (float or Tensor): real-valued parameter larger than the (dimension of the square matrix) - 1

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        :class:`torch.distributions.LKJCholesky` is a restricted Wishart distribution [1].

    **References**

    [1] `On equivalence of the LKJ distribution and the restricted Wishart distribution`,
    Zhenxun Wang, Yunan Wu, Haitao Chu.
    """
    arg_constraints = {
        'covariance_matrix': constraints.positive_definite,
        'precision_matrix': constraints.positive_definite,
        'scale_tril': constraints.lower_cholesky,
        'df': constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True

    def __init__(self,
                 df: Union[torch.Tensor, Number],
                 covariance_matrix: torch.Tensor = None,
                 precision_matrix: torch.Tensor = None,
                 scale_tril: torch.Tensor = None,
                 validate_args=None):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
            "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)

        if param.dim() < 2:
            raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(f"Value of df={df} expected to be greater than ndim={event_shape[-1] - 1}.")

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.")

        super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        df_shape = batch_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(df_shape)

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if 'covariance_matrix' in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if 'scale_tril' in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if 'precision_matrix' in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        return (
            self._unbroadcasted_scale_tril @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(
            identity, self._unbroadcasted_scale_tril
        ).expand(self._batch_shape + self._event_shape)

    @property
    def mean(self):
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V))

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        p = self._event_shape[-1]  # has singleton shape

        # Implements sampling via the Bartlett decomposition
        noise = self._dist_chi2.rsample(sample_shape).sqrt().diag_embed(dim1=-2, dim2=-1)
        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
        r"""
        .. warning::
            In some cases, the sampling algorithm based on the Bartlett decomposition may return
            singular matrix samples. Several attempts to correct singular samples are made by
            default, but the method may still end up returning singular matrix samples. Singular
            samples may produce `-inf` values in `.log_prob()`. In those cases, the user should
            validate the samples and either fix the value of `df` or adjust the
            `max_try_correction` argument of `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # The part below improves numerical stability temporarily and should be removed in the future
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            - nu * p * _log_2 / 2
            - nu * self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2
        )

    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        return (
            (p + 1) * self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
            + p * (p + 1) * _log_2 / 2
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        return (
            0.5 * self.df,
            - 0.5 * self.precision_matrix,
        )

    def _log_normalizer(self, x, y):
        p = y.shape[-1]
        return x * (- torch.linalg.slogdet(-2 * y).logabsdet + _log_2 * p) + _mvdigamma(x, p=p)
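For reference, `log_prob` above evaluates the log of the standard Wishart density. With scale matrix Σ = LLᵀ, ν degrees of freedom, and p the event dimension:

```latex
\log p(X \mid \nu, \Sigma)
  = \frac{\nu - p - 1}{2} \log\lvert X \rvert
  - \frac{1}{2} \operatorname{tr}\!\left(\Sigma^{-1} X\right)
  - \frac{\nu p}{2} \log 2
  - \frac{\nu}{2} \log\lvert \Sigma \rvert
  - \log \Gamma_p\!\left(\tfrac{\nu}{2}\right)
```

The code computes the term (ν/2) log|Σ| as ν · Σᵢ log Lᵢᵢ and the trace term via `torch.cholesky_solve`, which is why only `scale_tril` is needed internally.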
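Finally, `_bartlett_sampling` above implements the classic Bartlett decomposition: a lower-triangular factor A with chi-distributed diagonal entries and standard-normal entries below the diagonal, so that (LA)(LA)ᵀ is Wishart distributed. A self-contained sketch of the same construction (illustrative only, not the module's internals):

```python
import torch

def bartlett_sample(df: float, scale_tril: torch.Tensor) -> torch.Tensor:
    """Draw one Wishart(df, Sigma) sample, where Sigma = scale_tril @ scale_tril.T."""
    p = scale_tril.shape[-1]
    # Strictly-lower-triangular part: independent N(0, 1) entries.
    A = torch.randn(p, p).tril(-1)
    # Diagonal: sqrt of Chi2 with df, df - 1, ..., df - p + 1 degrees of
    # freedom (requires df > p - 1 so every Chi2 df is positive).
    chi2_df = df - torch.arange(p, dtype=scale_tril.dtype)
    A += torch.distributions.Chi2(chi2_df).sample().sqrt().diag_embed()
    L = scale_tril @ A
    return L @ L.mT

sample = bartlett_sample(5.0, torch.linalg.cholesky(torch.eye(3)))
print(sample)  # one 3x3 positive-definite draw
```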