Continuous Bernoulli distribution (take 2) (#34619)
Summary: We recently had a NeurIPS paper (https://arxiv.org/abs/1907.06845 and https://papers.nips.cc/paper/9484-the-continuous-bernoulli-fixing-a-pervasive-error-in-variational-autoencoders) in which we introduce a new [0,1]-supported distribution: the continuous Bernoulli. This pull request implements this distribution in PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/34619

Differential Revision: D20403123

Pulled By: ngimel

fbshipit-source-id: d807c7d0d372c6daf6cb6ef09df178bc7491abb2
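
Editorial note: for orientation, a minimal usage sketch of the class this commit adds (names and API as they appear in the diff below; printed values are illustrative, not exact):

    import torch
    from torch.distributions import ContinuousBernoulli

    m = ContinuousBernoulli(probs=torch.tensor([0.3]))
    x = m.rsample()        # reparameterized sample in [0, 1] (has_rsample = True)
    print(m.log_prob(x))   # log density, including the log normalizing constant
    print(m.mean)          # note: not equal to 'probs' (see the class docstring below)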

parent 944ea4c334
commit a74fbea345

5 changed files with 454 additions and 3 deletions

docs/source/distributions.rst

@@ -77,6 +77,15 @@ Probability distributions - torch.distributions
     :undoc-members:
     :show-inheritance:
 
+:hidden:`ContinuousBernoulli`
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torch.distributions.continuous_bernoulli
+.. autoclass:: ContinuousBernoulli
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 :hidden:`Dirichlet`
 ~~~~~~~~~~~~~~~~~~~~~~~
 

test/test_distributions.py

@@ -40,8 +40,8 @@ from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_se
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.autograd import grad, gradcheck
 from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
-                                 Cauchy, Chi2, Dirichlet, Distribution,
-                                 Exponential, ExponentialFamily,
+                                 Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
+                                 Distribution, Exponential, ExponentialFamily,
                                  FisherSnedecor, Gamma, Geometric, Gumbel,
                                  HalfCauchy, HalfNormal,
                                  Independent, Laplace, LogisticNormal,

@@ -452,7 +452,13 @@ EXAMPLES = [
         {
             'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
             'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
         }
     ]),
+    Example(ContinuousBernoulli, [
+        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
+        {'probs': torch.tensor([0.3], requires_grad=True)},
+        {'probs': 0.3},
+        {'logits': torch.tensor([0.], requires_grad=True)},
+    ])
 ]
 

@@ -673,6 +679,11 @@ BAD_EXAMPLES = [
             'scale': torch.tensor([1.0], requires_grad=True),
             'concentration': torch.tensor([-1.0], requires_grad=True)
         }
     ]),
+    Example(ContinuousBernoulli, [
+        {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
+        {'probs': torch.tensor([-0.5], requires_grad=True)},
+        {'probs': 1.00001},
+    ])
 ]
 

@@ -2402,6 +2413,44 @@ class TestDistributions(TestCase):
         self.assertEqual(frac_zeros, 0.5, 0.12)
         self.assertEqual(frac_ones, 0.5, 0.12)
 
+    def test_continuous_bernoulli(self):
+        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
+        r = torch.tensor(0.3, requires_grad=True)
+        s = 0.3
+        self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
+        self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
+        self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
+        self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
+        self.assertEqual(ContinuousBernoulli(r).sample((3, 2)).size(), (3, 2,))
+        self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
+        self._gradcheck_log_prob(ContinuousBernoulli, (p,))
+
+        def ref_log_prob(idx, val, log_prob):
+            prob = p[idx]
+            if prob > 0.499 and prob < 0.501:  # using default value of lim here
+                log_norm_const = math.log(2.) + 4. / 3. * math.pow(prob - 0.5, 2) + 104. / 45. * math.pow(prob - 0.5, 4)
+            else:
+                log_norm_const = math.log(2. * math.atanh(1. - 2. * prob) / (1. - 2.0 * prob))
+            res = val * math.log(prob) + (1. - val) * math.log1p(-prob) + log_norm_const
+            self.assertEqual(log_prob, res)
+
+        self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
+        self._check_log_prob(ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
+
+        # check entropy computation
+        self.assertEqual(ContinuousBernoulli(p).entropy(), torch.tensor([-0.02938, -0.07641, -0.00682]), prec=1e-4)
+        # entropy below corresponds to the clamped value of prob when using float 64
+        # the value for float32 should be -1.76898
+        self.assertEqual(ContinuousBernoulli(torch.tensor([0.0])).entropy(), torch.tensor([-2.58473]))
+        self.assertEqual(ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), prec=1e-4)
+
+    def test_continuous_bernoulli_3d(self):
+        p = torch.full((2, 3, 5), 0.5).requires_grad_()
+        self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
+        self.assertEqual(ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(),
+                         (2, 5, 2, 3, 5))
+        self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
+
     def test_independent_shape(self):
         for Dist, params in EXAMPLES:
             for param in params:
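
Editorial note, not part of the diff: the ref_log_prob helper above encodes the log normalizing constant of the continuous Bernoulli from the paper cited in the summary. Writing $\lambda$ for 'probs',

$$\log C(\lambda) = \log\!\left(\frac{2\tanh^{-1}(1 - 2\lambda)}{1 - 2\lambda}\right), \qquad \lambda \neq \tfrac{1}{2},$$

and, inside the numerically unstable region around $\lambda = \tfrac{1}{2}$, the code switches to the Taylor expansion

$$\log C(\lambda) \approx \log 2 + \tfrac{4}{3}\bigl(\lambda - \tfrac{1}{2}\bigr)^{2} + \tfrac{104}{45}\bigl(\lambda - \tfrac{1}{2}\bigr)^{4}.$$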

@@ -3271,6 +3320,26 @@ class TestDistributionShapes(TestCase):
         self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
         self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
 
+    def test_continuous_bernoulli_shape_scalar_params(self):
+        continuous_bernoulli = ContinuousBernoulli(0.3)
+        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
+        self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
+        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
+        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
+        self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+
+    def test_continuous_bernoulli_shape_tensor_params(self):
+        continuous_bernoulli = ContinuousBernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
+        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
+        self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
+        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
+        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
+        self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2)
+        self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
 
 
 class TestKL(TestCase):

@@ -3316,6 +3385,7 @@ class TestKL(TestCase):
         uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
         uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
         uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5])
+        continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])
 
         # These tests should pass with precision = 0.01, but that makes tests very expensive.
         # Instead, we test with precision = 0.1 and only test with higher precision locally

@@ -3374,6 +3444,10 @@ class TestKL(TestCase):
             (uniform_real, gumbel),
             (uniform_real, normal),
             (uniform_pareto, pareto),
+            (continuous_bernoulli, continuous_bernoulli),
+            (continuous_bernoulli, exponential),
+            (continuous_bernoulli, normal),
+            (beta, continuous_bernoulli)
         ]
 
         self.infinite_examples = [

@@ -3429,6 +3503,18 @@ class TestKL(TestCase):
             (Uniform(-1, 2), Exponential(3)),
             (Uniform(-1, 2), Gamma(3, 4)),
             (Uniform(-1, 2), Pareto(3, 4)),
+            (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
+            (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
+            (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
+            (ContinuousBernoulli(0.25), Pareto(1, 2)),
+            (Exponential(1), ContinuousBernoulli(0.75)),
+            (Gamma(1, 2), ContinuousBernoulli(0.75)),
+            (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
+            (Laplace(-1, 2), ContinuousBernoulli(0.75)),
+            (Normal(-1, 2), ContinuousBernoulli(0.75)),
+            (Uniform(-1, 1), ContinuousBernoulli(0.75)),
+            (Uniform(0, 2), ContinuousBernoulli(0.75)),
+            (Uniform(-1, 2), ContinuousBernoulli(0.75))
         ]
 
     def test_kl_monte_carlo(self):

@@ -3787,10 +3873,107 @@ class TestNumericalStability(TestCase):
         log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
         self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)
 
+    def test_continuous_bernoulli_gradient(self):
+
+        def expec_val(x, probs=None, logits=None):
+            assert not (probs is None and logits is None)
+            if logits is not None:
+                probs = 1. / (1. + math.exp(-logits))
+            bern_log_lik = x * math.log(probs) + (1. - x) * math.log1p(-probs)
+            if probs < 0.499 or probs > 0.501:  # using default values of lims here
+                log_norm_const = math.log(
+                    math.fabs(math.atanh(1. - 2. * probs))) - math.log(math.fabs(1. - 2. * probs)) + math.log(2.)
+            else:
+                aux = math.pow(probs - 0.5, 2)
+                log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
+            log_lik = bern_log_lik + log_norm_const
+            return log_lik
+
+        def expec_grad(x, probs=None, logits=None):
+            assert not (probs is None and logits is None)
+            if logits is not None:
+                probs = 1. / (1. + math.exp(-logits))
+            grad_bern_log_lik = x / probs - (1. - x) / (1. - probs)
+            if probs < 0.499 or probs > 0.501:  # using default values of lims here
+                grad_log_c = 2. * probs - 4. * (probs - 1.) * probs * math.atanh(1. - 2. * probs) - 1.
+                grad_log_c /= 2. * (probs - 1.) * probs * (2. * probs - 1.) * math.atanh(1. - 2. * probs)
+            else:
+                grad_log_c = 8. / 3. * (probs - 0.5) + 416. / 45. * math.pow(probs - 0.5, 3)
+            grad = grad_bern_log_lik + grad_log_c
+            if logits is not None:
+                grad *= 1. / (1. + math.exp(logits)) - 1. / math.pow(1. + math.exp(logits), 2)
+            return grad
+
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([0.1]),
+                                 x=tensor_type([0.1]),
+                                 expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
+                                 expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]))
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([0.1]),
+                                 x=tensor_type([1.]),
+                                 expected_value=tensor_type([expec_val(1., probs=0.1)]),
+                                 expected_gradient=tensor_type([expec_grad(1., probs=0.1)]))
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([0.4999]),
+                                 x=tensor_type([0.9]),
+                                 expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
+                                 expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]))
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([1e-4]),
+                                 x=tensor_type([1]),
+                                 expected_value=tensor_type([expec_val(1, probs=1e-4)]),
+                                 expected_gradient=tensor_type(tensor_type([expec_grad(1, probs=1e-4)])),
+                                 prec=1e-3)
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([1 - 1e-4]),
+                                 x=tensor_type([0.1]),
+                                 expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
+                                 expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
+                                 prec=2)
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([math.log(9999)]),
+                                 x=tensor_type([0]),
+                                 expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
+                                 expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
+                                 prec=1e-3)
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([0.001]),
+                                 x=tensor_type([0.5]),
+                                 expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
+                                 expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]))
+
+    def test_continuous_bernoulli_with_logits_underflow(self):
+        for tensor_type, lim, expected in ([(torch.FloatTensor, -1e38, 2.76898),
+                                            (torch.DoubleTensor, -1e308, 3.58473)]):
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([lim]),
+                                 x=tensor_type([0]),
+                                 expected_value=tensor_type([expected]),
+                                 expected_gradient=tensor_type([0.]))
+
+    def test_continuous_bernoulli_with_logits_overflow(self):
+        for tensor_type, lim, expected in ([(torch.FloatTensor, 1e38, 2.76898),
+                                            (torch.DoubleTensor, 1e308, 3.58473)]):
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([lim]),
+                                 x=tensor_type([1]),
+                                 expected_value=tensor_type([expected]),
+                                 expected_gradient=tensor_type([0.]))
+
 
 class TestLazyLogitsInitialization(TestCase):
     def setUp(self):
         super(TestLazyLogitsInitialization, self).setUp()
+        # ContinuousBernoulli is not tested because log_prob is not computed simply
+        # from 'logits', but 'probs' is also needed
         self.examples = [e for e in EXAMPLES if e.Dist in
                          (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)]
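
Editorial note, not part of the diff: the underflow/overflow tests above work because extreme logits saturate the sigmoid, and clamp_probs (from torch.distributions.utils, used in the new module below) pulls 'probs' back into the open interval (0, 1) by the dtype's machine epsilon; this is why the expected values differ between float32 and float64. A small sketch, assuming default dtypes:

    import torch
    from torch.distributions.utils import clamp_probs

    p32 = torch.sigmoid(torch.tensor([-1e38]))   # underflows to 0.
    print(clamp_probs(p32))                      # ~1.1921e-07 (float32 eps)
    p64 = torch.sigmoid(torch.tensor([-1e308], dtype=torch.float64))
    print(clamp_probs(p64))                      # ~2.2204e-16 (float64 eps)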

torch/distributions/__init__.py

@@ -78,6 +78,7 @@ from .categorical import Categorical
 from .cauchy import Cauchy
 from .chi2 import Chi2
 from .constraint_registry import biject_to, transform_to
+from .continuous_bernoulli import ContinuousBernoulli
 from .dirichlet import Dirichlet
 from .distribution import Distribution
 from .exp_family import ExponentialFamily

@@ -118,6 +119,7 @@ __all__ = [
     'Categorical',
     'Cauchy',
     'Chi2',
+    'ContinuousBernoulli',
     'Dirichlet',
     'Distribution',
     'Exponential',

torch/distributions/continuous_bernoulli.py (new file, 198 lines)

@@ -0,0 +1,198 @@
from numbers import Number
import math

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs
from torch.nn.functional import binary_cross_entropy_with_logits


class ContinuousBernoulli(ExponentialFamily):
    r"""
    Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both).

    The distribution is supported in [0, 1] and parameterized by 'probs' (in
    (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
    does not correspond to a probability and 'logits' does not correspond to
    log-odds, but the same names are used due to the similarity with the
    Bernoulli. See [1] for more details.

    Example::

        >>> m = ContinuousBernoulli(torch.tensor([0.3]))
        >>> m.sample()
        tensor([ 0.2538])

    Args:
        probs (Number, Tensor): (0,1) valued parameters
        logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'

    [1] The continuous Bernoulli: fixing a pervasive error in variational
    autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
    https://arxiv.org/abs/1907.06845
    """
    arg_constraints = {'probs': constraints.unit_interval,
                       'logits': constraints.real}
    support = constraints.unit_interval
    _mean_carrier_measure = 0
    has_rsample = True

    def __init__(self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            self.probs, = broadcast_all(probs)
            # validate 'probs' here if necessary as it is later clamped for numerical stability
            # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
            if validate_args is not None:
                if not self.arg_constraints['probs'].check(getattr(self, 'probs')).all():
                    raise ValueError("The parameter {} has invalid values".format('probs'))
            self.probs = clamp_probs(self.probs)
        else:
            is_scalar = isinstance(logits, Number)
            self.logits, = broadcast_all(logits)
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        self._lims = lims
        super(ContinuousBernoulli, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(ContinuousBernoulli, _instance)
        new._lims = self._lims
        batch_shape = torch.Size(batch_shape)
        if 'probs' in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if 'logits' in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._param.new(*args, **kwargs)

    def _outside_unstable_region(self):
        return torch.max(torch.le(self.probs, self._lims[0]),
                         torch.gt(self.probs, self._lims[1]))

    def _cut_probs(self):
        return torch.where(self._outside_unstable_region(),
                           self.probs,
                           self._lims[0] * torch.ones_like(self.probs))

    def _cont_bern_log_norm(self):
        '''computes the log normalizing constant as a function of the 'probs' parameter'''
        cut_probs = self._cut_probs()
        cut_probs_below_half = torch.where(torch.le(cut_probs, 0.5),
                                           cut_probs,
                                           torch.zeros_like(cut_probs))
        cut_probs_above_half = torch.where(torch.ge(cut_probs, 0.5),
                                           cut_probs,
                                           torch.ones_like(cut_probs))
        log_norm = torch.log(torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))) - torch.where(
            torch.le(cut_probs, 0.5),
            torch.log1p(-2.0 * cut_probs_below_half),
            torch.log(2.0 * cut_probs_above_half - 1.0))
        x = torch.pow(self.probs - 0.5, 2)
        taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
        return torch.where(self._outside_unstable_region(), log_norm, taylor)

    @property
    def mean(self):
        cut_probs = self._cut_probs()
        mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (torch.log1p(-cut_probs) - torch.log(cut_probs))
        x = self.probs - 0.5
        taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
        return torch.where(self._outside_unstable_region(), mus, taylor)

    @property
    def stddev(self):
        return torch.sqrt(self.variance)

    @property
    def variance(self):
        cut_probs = self._cut_probs()
        vars = cut_probs * (cut_probs - 1.0) / torch.pow(1.0 - 2.0 * cut_probs, 2) + 1.0 / torch.pow(
            torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
        x = torch.pow(self.probs - 0.5, 2)
        taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128. / 945.0 * x) * x
        return torch.where(self._outside_unstable_region(), vars, taylor)

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return clamp_probs(logits_to_probs(self.logits, is_binary=True))

    @property
    def param_shape(self):
        return self._param.size()

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
        with torch.no_grad():
            return self.icdf(u)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
        return self.icdf(u)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        return -binary_cross_entropy_with_logits(logits, value, reduction='none') + self._cont_bern_log_norm()

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        cut_probs = self._cut_probs()
        cdfs = (torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
                + cut_probs - 1.0) / (2.0 * cut_probs - 1.0)
        unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
        return torch.where(
            torch.le(value, 0.0),
            torch.zeros_like(value),
            torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs))

    def icdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        cut_probs = self._cut_probs()
        return torch.where(
            self._outside_unstable_region(),
            (torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
             - torch.log1p(-cut_probs)) / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
            value)

    def entropy(self):
        log_probs0 = torch.log1p(-self.probs)
        log_probs1 = torch.log(self.probs)
        return self.mean * (log_probs0 - log_probs1) - self._cont_bern_log_norm() - log_probs0

    @property
    def _natural_params(self):
        return (self.logits, )

    def _log_normalizer(self, x):
        """computes the log normalizing constant as a function of the natural parameter"""
        out_unst_reg = torch.max(torch.le(x, self._lims[0] - 0.5),
                                 torch.gt(x, self._lims[1] - 0.5))
        cut_nat_params = torch.where(out_unst_reg,
                                     x,
                                     (self._lims[0] - 0.5) * torch.ones_like(x))
        log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(torch.abs(cut_nat_params))
        taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
        return torch.where(out_unst_reg, log_norm, taylor)
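
Editorial note, not part of the diff: to underline the docstring's caveat that 'probs' is not the mean of this distribution, a quick hedged check (expected output computed from the mean formula above, approximate):

    import torch
    from torch.distributions import ContinuousBernoulli

    d = ContinuousBernoulli(probs=torch.tensor(0.3))
    print(d.mean)    # ~0.4302, not 0.3: lambda/(2*lambda - 1) + 1/(log1p(-lambda) - log(lambda))
    print(d.stddev)  # square root of the 'variance' property above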

torch/distributions/kl.py

@@ -9,6 +9,7 @@ from .bernoulli import Bernoulli
 from .beta import Beta
 from .binomial import Binomial
 from .categorical import Categorical
+from .continuous_bernoulli import ContinuousBernoulli
 from .dirichlet import Dirichlet
 from .distribution import Distribution
 from .exponential import Exponential

@@ -217,6 +218,14 @@ def _kl_categorical_categorical(p, q):
     return t.sum(-1)
 
 
+@register_kl(ContinuousBernoulli, ContinuousBernoulli)
+def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
+    t1 = p.mean * (p.logits - q.logits)
+    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
+    t3 = - q._cont_bern_log_norm() - torch.log1p(-q.probs)
+    return t1 + t2 + t3
+
+
 @register_kl(Dirichlet, Dirichlet)
 def _kl_dirichlet_dirichlet(p, q):
     # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
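
Editorial note, not part of the diff: with the registration above, torch.distributions.kl_divergence dispatches to this closed form. A minimal sketch:

    import torch
    from torch.distributions import ContinuousBernoulli
    from torch.distributions.kl import kl_divergence

    p = ContinuousBernoulli(probs=torch.tensor(0.2))
    q = ContinuousBernoulli(probs=torch.tensor(0.5))
    print(kl_divergence(p, q))  # >= 0, and 0 when p and q coincide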

@@ -445,6 +454,11 @@ def _kl_bernoulli_poisson(p, q):
     return -p.entropy() - (p.probs * q.rate.log() - q.rate)
 
 
+@register_kl(Beta, ContinuousBernoulli)
+def _kl_beta_continuous_bernoulli(p, q):
+    return -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm()
+
+
 @register_kl(Beta, Pareto)
 def _kl_beta_infinity(p, q):
     return _infinite_like(p.concentration1)

@@ -484,8 +498,40 @@ def _kl_beta_uniform(p, q):
     result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
     return result
 
+# Note that the KL between a ContinuousBernoulli and Beta has no closed form
+
+
+@register_kl(ContinuousBernoulli, Pareto)
+def _kl_continuous_bernoulli_infinity(p, q):
+    return _infinite_like(p.probs)
+
+
+@register_kl(ContinuousBernoulli, Exponential)
+def _kl_continuous_bernoulli_exponential(p, q):
+    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean
+
+# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
+# TODO: Add ContinuousBernoulli-Laplace KL Divergence
+
+
+@register_kl(ContinuousBernoulli, Normal)
+def _kl_continuous_bernoulli_normal(p, q):
+    t1 = -p.entropy()
+    t2 = 0.5 * (math.log(2. * math.pi) + torch.square(q.loc / q.scale)) + torch.log(q.scale)
+    t3 = (p.variance + torch.square(p.mean) - 2. * q.loc * p.mean) / (2.0 * torch.square(q.scale))
+    return t1 + t2 + t3
+
+
+@register_kl(ContinuousBernoulli, Uniform)
+def _kl_continuous_bernoulli_uniform(p, q):
+    result = -p.entropy() + (q.high - q.low).log()
+    return torch.where(torch.max(torch.ge(q.low, p.support.lower_bound),
+                                 torch.le(q.high, p.support.upper_bound)),
+                       torch.ones_like(result) * inf, result)
+
 
 @register_kl(Exponential, Beta)
+@register_kl(Exponential, ContinuousBernoulli)
 @register_kl(Exponential, Pareto)
 @register_kl(Exponential, Uniform)
 def _kl_exponential_infinity(p, q):

@@ -523,6 +569,7 @@ def _kl_exponential_normal(p, q):
 
 
 @register_kl(Gamma, Beta)
+@register_kl(Gamma, ContinuousBernoulli)
 @register_kl(Gamma, Pareto)
 @register_kl(Gamma, Uniform)
 def _kl_gamma_infinity(p, q):

@@ -558,6 +605,7 @@ def _kl_gamma_normal(p, q):
 
 
 @register_kl(Gumbel, Beta)
+@register_kl(Gumbel, ContinuousBernoulli)
 @register_kl(Gumbel, Exponential)
 @register_kl(Gumbel, Gamma)
 @register_kl(Gumbel, Pareto)

@@ -578,6 +626,7 @@ def _kl_gumbel_normal(p, q):
 
 
 @register_kl(Laplace, Beta)
+@register_kl(Laplace, ContinuousBernoulli)
 @register_kl(Laplace, Exponential)
 @register_kl(Laplace, Gamma)
 @register_kl(Laplace, Pareto)

@@ -598,6 +647,7 @@ def _kl_laplace_normal(p, q):
 
 
 @register_kl(Normal, Beta)
+@register_kl(Normal, ContinuousBernoulli)
 @register_kl(Normal, Exponential)
 @register_kl(Normal, Gamma)
 @register_kl(Normal, Pareto)

@@ -620,6 +670,7 @@ def _kl_normal_gumbel(p, q):
 
 
 @register_kl(Pareto, Beta)
+@register_kl(Pareto, ContinuousBernoulli)
 @register_kl(Pareto, Uniform)
 def _kl_pareto_infinity(p, q):
     return _infinite_like(p.scale)

@@ -681,6 +732,14 @@ def _kl_uniform_beta(p, q):
     return result
 
 
+@register_kl(Uniform, ContinuousBernoulli)
+def _kl_uniform_continuous_bernoulli(p, q):
+    result = -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm()
+    return torch.where(torch.max(torch.ge(p.high, q.support.upper_bound),
+                                 torch.le(p.low, q.support.lower_bound)),
+                       torch.ones_like(result) * inf, result)
+
+
 @register_kl(Uniform, Exponential)
 def _kl_uniform_exponetial(p, q):
     result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()