Continuous Bernoulli distribution (take 2) (#34619)

Summary:
We recently published a NeurIPS paper (https://arxiv.org/abs/1907.06845 and https://papers.nips.cc/paper/9484-the-continuous-bernoulli-fixing-a-pervasive-error-in-variational-autoencoders) introducing a new distribution supported on [0, 1]: the continuous Bernoulli. This pull request implements the distribution in PyTorch.
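A minimal usage sketch (assuming this PR is merged, so that `ContinuousBernoulli` is importable from `torch.distributions`):

```python
import torch
from torch.distributions import ContinuousBernoulli

# 'probs' is a shape parameter in (0, 1); unlike the discrete Bernoulli,
# it is not itself a probability (see the paper for details).
cb = ContinuousBernoulli(probs=torch.tensor([0.3, 0.7]))

x = cb.sample((5,))          # shape (5, 2), values in [0, 1]
print(cb.log_prob(x).shape)  # torch.Size([5, 2])
print(cb.mean, cb.variance)  # closed-form moments

# rsample() draws reparameterized samples through the inverse CDF, so
# gradients flow back to the parameters -- the VAE use case from the paper.
lam = torch.tensor([0.3], requires_grad=True)
ContinuousBernoulli(probs=lam).rsample((100,)).mean().backward()
print(lam.grad)
```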
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34619

Differential Revision: D20403123

Pulled By: ngimel

fbshipit-source-id: d807c7d0d372c6daf6cb6ef09df178bc7491abb2
Authored by gabloa on 2020-03-12 11:43:58 -07:00; committed by Facebook GitHub Bot
parent 944ea4c334
commit a74fbea345
5 changed files with 454 additions and 3 deletions


@@ -77,6 +77,15 @@ Probability distributions - torch.distributions
    :undoc-members:
    :show-inheritance:

:hidden:`ContinuousBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.continuous_bernoulli
.. autoclass:: ContinuousBernoulli
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`Dirichlet`
~~~~~~~~~~~~~~~~~~~~~~~


@@ -40,8 +40,8 @@ from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_se
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.autograd import grad, gradcheck
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
Cauchy, Chi2, Dirichlet, Distribution,
Exponential, ExponentialFamily,
Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
Distribution, Exponential, ExponentialFamily,
FisherSnedecor, Gamma, Geometric, Gumbel,
HalfCauchy, HalfNormal,
Independent, Laplace, LogisticNormal,
@@ -452,7 +452,13 @@ EXAMPLES = [
{
'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
}
},
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([0.3], requires_grad=True)},
{'probs': 0.3},
{'logits': torch.tensor([0.], requires_grad=True)},
])
]
@@ -673,6 +679,11 @@ BAD_EXAMPLES = [
'scale': torch.tensor([1.0], requires_grad=True),
'concentration': torch.tensor([-1.0], requires_grad=True)
}
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([-0.5], requires_grad=True)},
{'probs': 1.00001},
])
]
@@ -2402,6 +2413,44 @@ class TestDistributions(TestCase):
self.assertEqual(frac_zeros, 0.5, 0.12)
self.assertEqual(frac_ones, 0.5, 0.12)
def test_continuous_bernoulli(self):
p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
r = torch.tensor(0.3, requires_grad=True)
s = 0.3
self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
self.assertEqual(ContinuousBernoulli(r).sample((3, 2)).size(), (3, 2,))
self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
self._gradcheck_log_prob(ContinuousBernoulli, (p,))
def ref_log_prob(idx, val, log_prob):
prob = p[idx]
if prob > 0.499 and prob < 0.501:  # using default value of lims here
log_norm_const = math.log(2.) + 4. / 3. * math.pow(prob - 0.5, 2) + 104. / 45. * math.pow(prob - 0.5, 4)
else:
log_norm_const = math.log(2. * math.atanh(1. - 2. * prob) / (1. - 2.0 * prob))
res = val * math.log(prob) + (1. - val) * math.log1p(-prob) + log_norm_const
self.assertEqual(log_prob, res)
self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
self._check_log_prob(ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
# check entropy computation
self.assertEqual(ContinuousBernoulli(p).entropy(), torch.tensor([-0.02938, -0.07641, -0.00682]), prec=1e-4)
# entropy below corresponds to the clamped value of prob when using float64
# the value for float32 should be -1.76898
self.assertEqual(ContinuousBernoulli(torch.tensor([0.0])).entropy(), torch.tensor([-2.58473]))
self.assertEqual(ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), prec=1e-4)
def test_continuous_bernoulli_3d(self):
p = torch.full((2, 3, 5), 0.5).requires_grad_()
self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
self.assertEqual(ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(),
(2, 5, 2, 3, 5))
self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
def test_independent_shape(self):
for Dist, params in EXAMPLES:
for param in params:
@@ -3271,6 +3320,26 @@ class TestDistributionShapes(TestCase):
self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
def test_continuous_bernoulli_shape_scalar_params(self):
continuous_bernoulli = ContinuousBernoulli(0.3)
self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
def test_continuous_bernoulli_shape_tensor_params(self):
continuous_bernoulli = ContinuousBernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2)
self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
class TestKL(TestCase):
@@ -3316,6 +3385,7 @@ class TestKL(TestCase):
uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5])
continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])
# These tests should pass with precision = 0.01, but that makes tests very expensive.
# Instead, we test with precision = 0.1 and only test with higher precision locally
@@ -3374,6 +3444,10 @@ class TestKL(TestCase):
(uniform_real, gumbel),
(uniform_real, normal),
(uniform_pareto, pareto),
(continuous_bernoulli, continuous_bernoulli),
(continuous_bernoulli, exponential),
(continuous_bernoulli, normal),
(beta, continuous_bernoulli)
]
self.infinite_examples = [
@@ -3429,6 +3503,18 @@ class TestKL(TestCase):
(Uniform(-1, 2), Exponential(3)),
(Uniform(-1, 2), Gamma(3, 4)),
(Uniform(-1, 2), Pareto(3, 4)),
(ContinuousBernoulli(0.25), Uniform(0.25, 1)),
(ContinuousBernoulli(0.25), Uniform(0, 0.75)),
(ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
(ContinuousBernoulli(0.25), Pareto(1, 2)),
(Exponential(1), ContinuousBernoulli(0.75)),
(Gamma(1, 2), ContinuousBernoulli(0.75)),
(Gumbel(-1, 2), ContinuousBernoulli(0.75)),
(Laplace(-1, 2), ContinuousBernoulli(0.75)),
(Normal(-1, 2), ContinuousBernoulli(0.75)),
(Uniform(-1, 1), ContinuousBernoulli(0.75)),
(Uniform(0, 2), ContinuousBernoulli(0.75)),
(Uniform(-1, 2), ContinuousBernoulli(0.75))
]
def test_kl_monte_carlo(self):
@@ -3787,10 +3873,107 @@ class TestNumericalStability(TestCase):
log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)
def test_continuous_bernoulli_gradient(self):
def expec_val(x, probs=None, logits=None):
assert not (probs is None and logits is None)
if logits is not None:
probs = 1. / (1. + math.exp(-logits))
bern_log_lik = x * math.log(probs) + (1. - x) * math.log1p(-probs)
if probs < 0.499 or probs > 0.501: # using default values of lims here
log_norm_const = math.log(
math.fabs(math.atanh(1. - 2. * probs))) - math.log(math.fabs(1. - 2. * probs)) + math.log(2.)
else:
aux = math.pow(probs - 0.5, 2)
log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
log_lik = bern_log_lik + log_norm_const
return log_lik
def expec_grad(x, probs=None, logits=None):
assert not (probs is None and logits is None)
if logits is not None:
probs = 1. / (1. + math.exp(-logits))
grad_bern_log_lik = x / probs - (1. - x) / (1. - probs)
if probs < 0.499 or probs > 0.501: # using default values of lims here
grad_log_c = 2. * probs - 4. * (probs - 1.) * probs * math.atanh(1. - 2. * probs) - 1.
grad_log_c /= 2. * (probs - 1.) * probs * (2. * probs - 1.) * math.atanh(1. - 2. * probs)
else:
grad_log_c = 8. / 3. * (probs - 0.5) + 416. / 45. * math.pow(probs - 0.5, 3)
grad = grad_bern_log_lik + grad_log_c
if logits is not None:
grad *= 1. / (1. + math.exp(logits)) - 1. / math.pow(1. + math.exp(logits), 2)
return grad
for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([0.1]),
x=tensor_type([0.1]),
expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]))
self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([0.1]),
x=tensor_type([1.]),
expected_value=tensor_type([expec_val(1., probs=0.1)]),
expected_gradient=tensor_type([expec_grad(1., probs=0.1)]))
self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([0.4999]),
x=tensor_type([0.9]),
expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]))
self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([1e-4]),
x=tensor_type([1]),
expected_value=tensor_type([expec_val(1, probs=1e-4)]),
expected_gradient=tensor_type([expec_grad(1, probs=1e-4)]),
prec=1e-3)
self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([1 - 1e-4]),
x=tensor_type([0.1]),
expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
prec=2)
self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([math.log(9999)]),
x=tensor_type([0]),
expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
prec=1e-3)
self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([0.001]),
x=tensor_type([0.5]),
expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]))
def test_continuous_bernoulli_with_logits_underflow(self):
for tensor_type, lim, expected in ([(torch.FloatTensor, -1e38, 2.76898),
(torch.DoubleTensor, -1e308, 3.58473)]):
self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([lim]),
x=tensor_type([0]),
expected_value=tensor_type([expected]),
expected_gradient=tensor_type([0.]))
def test_continuous_bernoulli_with_logits_overflow(self):
for tensor_type, lim, expected in ([(torch.FloatTensor, 1e38, 2.76898),
(torch.DoubleTensor, 1e308, 3.58473)]):
self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([lim]),
x=tensor_type([1]),
expected_value=tensor_type([expected]),
expected_gradient=tensor_type([0.]))
class TestLazyLogitsInitialization(TestCase):
def setUp(self):
super(TestLazyLogitsInitialization, self).setUp()
# ContinuousBernoulli is not tested because log_prob is not computed simply
# from 'logits', but 'probs' is also needed
self.examples = [e for e in EXAMPLES if e.Dist in
(Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)]
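For reference, the expected values computed in `ref_log_prob` and `expec_val` above follow the continuous Bernoulli density from the paper, with λ playing the role of 'probs':

```latex
p(x \mid \lambda) = C(\lambda)\,\lambda^{x}(1-\lambda)^{1-x}, \qquad x \in [0,1],
\qquad
C(\lambda) =
\begin{cases}
2 & \text{if } \lambda = \tfrac{1}{2},\\
\dfrac{2\tanh^{-1}(1-2\lambda)}{1-2\lambda} & \text{otherwise.}
\end{cases}
```

Inside the numerically unstable window `lims=(0.499, 0.501)`, the tests (like the implementation) replace log C(λ) with its Taylor expansion around λ = 1/2, namely log 2 + (4/3)(λ − 1/2)² + (104/45)(λ − 1/2)⁴.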


@@ -78,6 +78,7 @@ from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
@@ -118,6 +119,7 @@ __all__ = [
'Categorical',
'Cauchy',
'Chi2',
'ContinuousBernoulli',
'Dirichlet',
'Distribution',
'Exponential',


@@ -0,0 +1,198 @@
from numbers import Number
import math
import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs
from torch.nn.functional import binary_cross_entropy_with_logits
class ContinuousBernoulli(ExponentialFamily):
r"""
Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
or :attr:`logits` (but not both).
The distribution is supported on [0, 1] and parameterized by 'probs' (in
(0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
does not correspond to a probability and 'logits' does not correspond to
log-odds, but the same names are used due to the similarity with the
Bernoulli. See [1] for more details.
Example::
>>> m = ContinuousBernoulli(torch.tensor([0.3]))
>>> m.sample()
tensor([ 0.2538])
Args:
probs (Number, Tensor): (0,1) valued parameters
logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
[1] The continuous Bernoulli: fixing a pervasive error in variational
autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
https://arxiv.org/abs/1907.06845
"""
arg_constraints = {'probs': constraints.unit_interval,
'logits': constraints.real}
support = constraints.unit_interval
_mean_carrier_measure = 0
has_rsample = True
def __init__(self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
is_scalar = isinstance(probs, Number)
self.probs, = broadcast_all(probs)
# validate 'probs' here if necessary, as it is later clamped for numerical stability
# close to 0 and 1; otherwise the clamped 'probs' would always pass validation
if validate_args is not None:
if not self.arg_constraints['probs'].check(getattr(self, 'probs')).all():
raise ValueError("The parameter {} has invalid values".format('probs'))
self.probs = clamp_probs(self.probs)
else:
is_scalar = isinstance(logits, Number)
self.logits, = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
batch_shape = torch.Size()
else:
batch_shape = self._param.size()
self._lims = lims
super(ContinuousBernoulli, self).__init__(batch_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(ContinuousBernoulli, _instance)
new._lims = self._lims
batch_shape = torch.Size(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)
def _outside_unstable_region(self):
return torch.max(torch.le(self.probs, self._lims[0]),
torch.gt(self.probs, self._lims[1]))
def _cut_probs(self):
return torch.where(self._outside_unstable_region(),
self.probs,
self._lims[0] * torch.ones_like(self.probs))
def _cont_bern_log_norm(self):
'''computes the log normalizing constant as a function of the 'probs' parameter'''
cut_probs = self._cut_probs()
cut_probs_below_half = torch.where(torch.le(cut_probs, 0.5),
cut_probs,
torch.zeros_like(cut_probs))
cut_probs_above_half = torch.where(torch.ge(cut_probs, 0.5),
cut_probs,
torch.ones_like(cut_probs))
log_norm = torch.log(torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))) - torch.where(
torch.le(cut_probs, 0.5),
torch.log1p(-2.0 * cut_probs_below_half),
torch.log(2.0 * cut_probs_above_half - 1.0))
x = torch.pow(self.probs - 0.5, 2)
taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
return torch.where(self._outside_unstable_region(), log_norm, taylor)
@property
def mean(self):
cut_probs = self._cut_probs()
mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (torch.log1p(-cut_probs) - torch.log(cut_probs))
x = self.probs - 0.5
taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
return torch.where(self._outside_unstable_region(), mus, taylor)
@property
def stddev(self):
return torch.sqrt(self.variance)
@property
def variance(self):
cut_probs = self._cut_probs()
vars = cut_probs * (cut_probs - 1.0) / torch.pow(1.0 - 2.0 * cut_probs, 2) + 1.0 / torch.pow(
torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
x = torch.pow(self.probs - 0.5, 2)
taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128. / 945.0 * x) * x
return torch.where(self._outside_unstable_region(), vars, taylor)
@lazy_property
def logits(self):
return probs_to_logits(self.probs, is_binary=True)
@lazy_property
def probs(self):
return clamp_probs(logits_to_probs(self.logits, is_binary=True))
@property
def param_shape(self):
return self._param.size()
def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
with torch.no_grad():
return self.icdf(u)
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
return self.icdf(u)
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
return -binary_cross_entropy_with_logits(logits, value, reduction='none') + self._cont_bern_log_norm()
def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
cut_probs = self._cut_probs()
cdfs = (torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
+ cut_probs - 1.0) / (2.0 * cut_probs - 1.0)
unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
return torch.where(
torch.le(value, 0.0),
torch.zeros_like(value),
torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs))
def icdf(self, value):
if self._validate_args:
self._validate_sample(value)
cut_probs = self._cut_probs()
return torch.where(
self._outside_unstable_region(),
(torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
- torch.log1p(-cut_probs)) / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
value)
def entropy(self):
log_probs0 = torch.log1p(-self.probs)
log_probs1 = torch.log(self.probs)
return self.mean * (log_probs0 - log_probs1) - self._cont_bern_log_norm() - log_probs0
@property
def _natural_params(self):
return (self.logits, )
def _log_normalizer(self, x):
"""computes the log normalizing constant as a function of the natural parameter"""
out_unst_reg = torch.max(torch.le(x, self._lims[0] - 0.5),
torch.gt(x, self._lims[1] - 0.5))
cut_nat_params = torch.where(out_unst_reg,
x,
(self._lims[0] - 0.5) * torch.ones_like(x))
log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(torch.abs(cut_nat_params))
taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
return torch.where(out_unst_reg, log_norm, taylor)
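As a quick editorial sanity check on the file above (a sketch, not part of the PR), the Taylor branch of the log normalizer should agree with the exact expression near probs = 0.5, and the resulting density should integrate to one:

```python
import math
import torch
from torch.distributions import ContinuousBernoulli

# Exact log normalizer vs. its Taylor expansion near the unstable point 0.5.
for p in (0.495, 0.4999, 0.5001, 0.505):
    exact = math.log(2.0 * math.atanh(1.0 - 2.0 * p) / (1.0 - 2.0 * p))
    aux = (p - 0.5) ** 2
    taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
    print(p, abs(exact - taylor))  # tiny, so the branch switch is seamless

# The density should integrate to one over [0, 1] (trapezoidal rule).
cb = ContinuousBernoulli(probs=torch.tensor(0.3, dtype=torch.float64))
xs = torch.linspace(0.0, 1.0, 100001, dtype=torch.float64)
print(torch.trapz(cb.log_prob(xs).exp(), xs))  # ~1.0
```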


@@ -9,6 +9,7 @@ from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential import Exponential
@@ -217,6 +218,14 @@ def _kl_categorical_categorical(p, q):
return t.sum(-1)
@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
t1 = p.mean * (p.logits - q.logits)
t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
t3 = - q._cont_bern_log_norm() - torch.log1p(-q.probs)
return t1 + t2 + t3
@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
# From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
@@ -445,6 +454,11 @@ def _kl_bernoulli_poisson(p, q):
return -p.entropy() - (p.probs * q.rate.log() - q.rate)
@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
return -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm()
@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
return _infinite_like(p.concentration1)
@@ -484,8 +498,40 @@ def _kl_beta_uniform(p, q):
result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
return result
# Note that the KL between a ContinuousBernoulli and Beta has no closed form
@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
return _infinite_like(p.probs)
@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
return -p.entropy() - torch.log(q.rate) + q.rate * p.mean
# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence
@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
t1 = -p.entropy()
t2 = 0.5 * (math.log(2. * math.pi) + torch.square(q.loc / q.scale)) + torch.log(q.scale)
t3 = (p.variance + torch.square(p.mean) - 2. * q.loc * p.mean) / (2.0 * torch.square(q.scale))
return t1 + t2 + t3
@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
result = -p.entropy() + (q.high - q.low).log()
return torch.where(torch.max(torch.ge(q.low, p.support.lower_bound),
torch.le(q.high, p.support.upper_bound)),
torch.ones_like(result) * inf, result)
@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
@@ -523,6 +569,7 @@ def _kl_exponential_normal(p, q):
@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
@@ -558,6 +605,7 @@ def _kl_gamma_normal(p, q):
@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@@ -578,6 +626,7 @@ def _kl_gumbel_normal(p, q):
@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@@ -598,6 +647,7 @@ def _kl_laplace_normal(p, q):
@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@@ -620,6 +670,7 @@ def _kl_normal_gumbel(p, q):
@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
return _infinite_like(p.scale)
@@ -681,6 +732,14 @@ def _kl_uniform_beta(p, q):
return result
@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
result = -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm()
return torch.where(torch.max(torch.ge(p.high, q.support.upper_bound),
torch.le(p.low, q.support.lower_bound)),
torch.ones_like(result) * inf, result)
@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
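As a final editorial sketch (again not part of the PR), the closed-form KLs registered above can be spot-checked against a Monte Carlo estimate of E_p[log p(x) − log q(x)]:

```python
import torch
from torch.distributions import ContinuousBernoulli, Normal, kl_divergence

torch.manual_seed(0)
p = ContinuousBernoulli(probs=torch.tensor(0.75, dtype=torch.float64))
q = Normal(torch.tensor(-1.0, dtype=torch.float64),
           torch.tensor(2.0, dtype=torch.float64))

analytic = kl_divergence(p, q)  # dispatches to _kl_continuous_bernoulli_normal
x = p.sample((1000000,))        # exact i.i.d. draws via the inverse CDF
monte_carlo = (p.log_prob(x) - q.log_prob(x)).mean()
print(analytic.item(), monte_carlo.item())  # should agree to a few decimals
```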