diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst
index 3ace6bd4e7b..beaec7de400 100644
--- a/docs/source/distributions.rst
+++ b/docs/source/distributions.rst
@@ -77,6 +77,15 @@ Probability distributions - torch.distributions
     :undoc-members:
     :show-inheritance:
 
+:hidden:`ContinuousBernoulli`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torch.distributions.continuous_bernoulli
+.. autoclass:: ContinuousBernoulli
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 :hidden:`Dirichlet`
 ~~~~~~~~~~~~~~~~~~~~~~~
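For orientation before the test and implementation diffs below, here is a minimal usage sketch of the class this PR adds (my own illustration, not part of the patch; the printed values are placeholders):

    import torch
    from torch.distributions import ContinuousBernoulli

    # 'probs' is a shape parameter in (0, 1), not a probability (see the class
    # docstring below): the density is C(probs) * probs**x * (1 - probs)**(1 - x).
    m = ContinuousBernoulli(probs=torch.tensor([0.3]))
    x = m.rsample()              # reparameterized draw in [0, 1]
    print(x, m.log_prob(x), m.entropy())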
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 3a866971040..83ab926fbb4 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -40,8 +40,8 @@ from torch.testing._internal.common_utils import TestCase, run_tests, set_rng_se
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.autograd import grad, gradcheck
 from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
-                                 Cauchy, Chi2, Dirichlet, Distribution,
-                                 Exponential, ExponentialFamily,
+                                 Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
+                                 Distribution, Exponential, ExponentialFamily,
                                  FisherSnedecor, Gamma, Geometric, Gumbel,
                                  HalfCauchy, HalfNormal, Independent, Laplace,
                                  LogisticNormal,
@@ -452,7 +452,13 @@ EXAMPLES = [
         {
             'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
             'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
-        }
+        },
+    ]),
+    Example(ContinuousBernoulli, [
+        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
+        {'probs': torch.tensor([0.3], requires_grad=True)},
+        {'probs': 0.3},
+        {'logits': torch.tensor([0.], requires_grad=True)},
     ])
 ]
 
@@ -673,6 +679,11 @@ BAD_EXAMPLES = [
             'scale': torch.tensor([1.0], requires_grad=True),
             'concentration': torch.tensor([-1.0], requires_grad=True)
         }
+    ]),
+    Example(ContinuousBernoulli, [
+        {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
+        {'probs': torch.tensor([-0.5], requires_grad=True)},
+        {'probs': 1.00001},
     ])
 ]
 
@@ -2402,6 +2413,44 @@ class TestDistributions(TestCase):
         self.assertEqual(frac_zeros, 0.5, 0.12)
         self.assertEqual(frac_ones, 0.5, 0.12)
 
+    def test_continuous_bernoulli(self):
+        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
+        r = torch.tensor(0.3, requires_grad=True)
+        s = 0.3
+        self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
+        self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
+        self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
+        self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
+        self.assertEqual(ContinuousBernoulli(r).sample((3, 2)).size(), (3, 2,))
+        self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
+        self._gradcheck_log_prob(ContinuousBernoulli, (p,))
+
+        def ref_log_prob(idx, val, log_prob):
+            prob = p[idx]
+            if prob > 0.499 and prob < 0.501:  # using default value of lims here
+                log_norm_const = math.log(2.) + 4. / 3. * math.pow(prob - 0.5, 2) + 104. / 45. * math.pow(prob - 0.5, 4)
+            else:
+                log_norm_const = math.log(2. * math.atanh(1. - 2. * prob) / (1. - 2.0 * prob))
+            res = val * math.log(prob) + (1. - val) * math.log1p(-prob) + log_norm_const
+            self.assertEqual(log_prob, res)
+
+        self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
+        self._check_log_prob(ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
+
+        # check entropy computation
+        self.assertEqual(ContinuousBernoulli(p).entropy(), torch.tensor([-0.02938, -0.07641, -0.00682]), prec=1e-4)
+        # entropy below corresponds to the clamped value of 'probs' when using float64;
+        # the corresponding float32 value would be -1.76898
+        self.assertEqual(ContinuousBernoulli(torch.tensor([0.0])).entropy(), torch.tensor([-2.58473]))
+        self.assertEqual(ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), prec=1e-4)
+
+    def test_continuous_bernoulli_3d(self):
+        p = torch.full((2, 3, 5), 0.5).requires_grad_()
+        self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
+        self.assertEqual(ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(),
+                         (2, 5, 2, 3, 5))
+        self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
+
     def test_independent_shape(self):
         for Dist, params in EXAMPLES:
             for param in params:
@@ -3271,6 +3320,26 @@ class TestDistributionShapes(TestCase):
         self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
         self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
 
+    def test_continuous_bernoulli_shape_scalar_params(self):
+        continuous_bernoulli = ContinuousBernoulli(0.3)
+        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
+        self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
+        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
+        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
+        self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
+
+    def test_continuous_bernoulli_shape_tensor_params(self):
+        continuous_bernoulli = ContinuousBernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
+        self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
+        self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
+        self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
+        self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
+        self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
+        self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2)
+        self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
+
 
 class TestKL(TestCase):
 
@@ -3316,6 +3385,7 @@ class TestKL(TestCase):
         uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
         uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
         uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5])
+        continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])
 
         # These tests should pass with precision = 0.01, but that makes tests very expensive.
         # Instead, we test with precision = 0.1 and only test with higher precision locally
@@ -3374,6 +3444,10 @@ class TestKL(TestCase):
             (uniform_real, gumbel),
             (uniform_real, normal),
             (uniform_pareto, pareto),
+            (continuous_bernoulli, continuous_bernoulli),
+            (continuous_bernoulli, exponential),
+            (continuous_bernoulli, normal),
+            (beta, continuous_bernoulli)
         ]
 
         self.infinite_examples = [
@@ -3429,6 +3503,18 @@ class TestKL(TestCase):
             (Uniform(-1, 2), Exponential(3)),
             (Uniform(-1, 2), Gamma(3, 4)),
             (Uniform(-1, 2), Pareto(3, 4)),
+            (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
+            (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
+            (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
+            (ContinuousBernoulli(0.25), Pareto(1, 2)),
+            (Exponential(1), ContinuousBernoulli(0.75)),
+            (Gamma(1, 2), ContinuousBernoulli(0.75)),
+            (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
+            (Laplace(-1, 2), ContinuousBernoulli(0.75)),
+            (Normal(-1, 2), ContinuousBernoulli(0.75)),
+            (Uniform(-1, 1), ContinuousBernoulli(0.75)),
+            (Uniform(0, 2), ContinuousBernoulli(0.75)),
+            (Uniform(-1, 2), ContinuousBernoulli(0.75))
         ]
 
     def test_kl_monte_carlo(self):
@@ -3787,10 +3873,107 @@ class TestNumericalStability(TestCase):
         log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
         self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)
 
+    def test_continuous_bernoulli_gradient(self):
+
+        def expec_val(x, probs=None, logits=None):
+            assert not (probs is None and logits is None)
+            if logits is not None:
+                probs = 1. / (1. + math.exp(-logits))
+            bern_log_lik = x * math.log(probs) + (1. - x) * math.log1p(-probs)
+            if probs < 0.499 or probs > 0.501:  # using default values of lims here
+                log_norm_const = math.log(
+                    math.fabs(math.atanh(1. - 2. * probs))) - math.log(math.fabs(1. - 2. * probs)) + math.log(2.)
+            else:
+                aux = math.pow(probs - 0.5, 2)
+                log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
+            log_lik = bern_log_lik + log_norm_const
+            return log_lik
+
+        def expec_grad(x, probs=None, logits=None):
+            assert not (probs is None and logits is None)
+            if logits is not None:
+                probs = 1. / (1. + math.exp(-logits))
+            grad_bern_log_lik = x / probs - (1. - x) / (1. - probs)
+            if probs < 0.499 or probs > 0.501:  # using default values of lims here
+                grad_log_c = 2. * probs - 4. * (probs - 1.) * probs * math.atanh(1. - 2. * probs) - 1.
+                grad_log_c /= 2. * (probs - 1.) * probs * (2. * probs - 1.) * math.atanh(1. - 2. * probs)
+            else:
+                grad_log_c = 8. / 3. * (probs - 0.5) + 416. / 45. * math.pow(probs - 0.5, 3)
+            grad = grad_bern_log_lik + grad_log_c
+            if logits is not None:
+                grad *= 1. / (1. + math.exp(logits)) - 1. / math.pow(1. + math.exp(logits), 2)
+            return grad
+
+        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([0.1]),
+                                 x=tensor_type([0.1]),
+                                 expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
+                                 expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]))
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([0.1]),
+                                 x=tensor_type([1.]),
+                                 expected_value=tensor_type([expec_val(1., probs=0.1)]),
+                                 expected_gradient=tensor_type([expec_grad(1., probs=0.1)]))
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([0.4999]),
+                                 x=tensor_type([0.9]),
+                                 expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
+                                 expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]))
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([1e-4]),
+                                 x=tensor_type([1]),
+                                 expected_value=tensor_type([expec_val(1, probs=1e-4)]),
+                                 expected_gradient=tensor_type([expec_grad(1, probs=1e-4)]),
+                                 prec=1e-3)
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 probs=tensor_type([1 - 1e-4]),
+                                 x=tensor_type([0.1]),
+                                 expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
+                                 expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
+                                 prec=2)
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([math.log(9999)]),
+                                 x=tensor_type([0]),
+                                 expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
+                                 expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
+                                 prec=1e-3)
+
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([0.001]),
+                                 x=tensor_type([0.5]),
+                                 expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
+                                 expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]))
+
+    def test_continuous_bernoulli_with_logits_underflow(self):
+        for tensor_type, lim, expected in ([(torch.FloatTensor, -1e38, 2.76898),
+                                            (torch.DoubleTensor, -1e308, 3.58473)]):
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([lim]),
+                                 x=tensor_type([0]),
+                                 expected_value=tensor_type([expected]),
+                                 expected_gradient=tensor_type([0.]))
+
+    def test_continuous_bernoulli_with_logits_overflow(self):
+        for tensor_type, lim, expected in ([(torch.FloatTensor, 1e38, 2.76898),
+                                            (torch.DoubleTensor, 1e308, 3.58473)]):
+            self._test_pdf_score(dist_class=ContinuousBernoulli,
+                                 logits=tensor_type([lim]),
+                                 x=tensor_type([1]),
+                                 expected_value=tensor_type([expected]),
+                                 expected_gradient=tensor_type([0.]))
+
 
 class TestLazyLogitsInitialization(TestCase):
     def setUp(self):
         super(TestLazyLogitsInitialization, self).setUp()
+        # ContinuousBernoulli is not tested because log_prob is not computed simply
+        # from 'logits', but 'probs' is also needed
         self.examples = [e for e in EXAMPLES if e.Dist in
                          (Categorical, OneHotCategorical, Bernoulli,
                           Binomial, Multinomial)]
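As a sanity check on the reference formula used in `ref_log_prob` and `expec_val` above, the normalizing constant should make the density integrate to one over [0, 1]. A standalone sketch of that check (my own, not part of the patch; midpoint-rule integration and the 1e-6 tolerance are arbitrary choices):

    import math

    def log_norm_const(prob, lims=(0.499, 0.501)):
        # Same two-branch scheme as ref_log_prob/expec_val above: a Taylor
        # expansion inside the numerically unstable region around 0.5, the
        # closed form elsewhere.
        if lims[0] < prob < lims[1]:
            return math.log(2.) + 4. / 3. * (prob - 0.5) ** 2 + 104. / 45. * (prob - 0.5) ** 4
        return math.log(2. * math.atanh(1. - 2. * prob) / (1. - 2. * prob))

    def density(x, prob):
        return math.exp(x * math.log(prob) + (1. - x) * math.log1p(-prob) + log_norm_const(prob))

    # midpoint rule on [0, 1]: the density should integrate to ~1
    for prob in (0.1, 0.4999, 0.7):
        n = 100000
        total = sum(density((i + 0.5) / n, prob) for i in range(n)) / n
        assert abs(total - 1.) < 1e-6, (prob, total)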
diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py
index 8a16ec10bd2..4d7a4bff96a 100644
--- a/torch/distributions/__init__.py
+++ b/torch/distributions/__init__.py
@@ -78,6 +78,7 @@ from .categorical import Categorical
 from .cauchy import Cauchy
 from .chi2 import Chi2
 from .constraint_registry import biject_to, transform_to
+from .continuous_bernoulli import ContinuousBernoulli
 from .dirichlet import Dirichlet
 from .distribution import Distribution
 from .exp_family import ExponentialFamily
@@ -118,6 +119,7 @@ __all__ = [
     'Categorical',
     'Cauchy',
     'Chi2',
+    'ContinuousBernoulli',
     'Dirichlet',
     'Distribution',
     'Exponential',
diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py
new file mode 100644
index 00000000000..180fbd8187e
--- /dev/null
+++ b/torch/distributions/continuous_bernoulli.py
@@ -0,0 +1,198 @@
+from numbers import Number
+import math
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs
+from torch.nn.functional import binary_cross_entropy_with_logits
+
+
+class ContinuousBernoulli(ExponentialFamily):
+    r"""
+    Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
+    or :attr:`logits` (but not both).
+
+    The distribution is supported in [0, 1] and parameterized by 'probs' (in
+    (0, 1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
+    does not correspond to a probability and 'logits' does not correspond to
+    log-odds, but the same names are used due to the similarity with the
+    Bernoulli. See [1] for more details.
+
+    Example::
+
+        >>> m = ContinuousBernoulli(torch.tensor([0.3]))
+        >>> m.sample()
+        tensor([ 0.2538])
+
+    Args:
+        probs (Number, Tensor): (0, 1) valued parameters
+        logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
+
+    [1] The continuous Bernoulli: fixing a pervasive error in variational
+    autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
+    https://arxiv.org/abs/1907.06845
+    """
+    arg_constraints = {'probs': constraints.unit_interval,
+                       'logits': constraints.real}
+    support = constraints.unit_interval
+    _mean_carrier_measure = 0
+    has_rsample = True
+
+    def __init__(self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None):
+        if (probs is None) == (logits is None):
+            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
+        if probs is not None:
+            is_scalar = isinstance(probs, Number)
+            self.probs, = broadcast_all(probs)
+            # validate 'probs' here if necessary, since it is later clamped away from 0 and 1
+            # for numerical stability; the clamped 'probs' would otherwise always pass validation
+            if validate_args is not None:
+                if not self.arg_constraints['probs'].check(getattr(self, 'probs')).all():
+                    raise ValueError("The parameter {} has invalid values".format('probs'))
+            self.probs = clamp_probs(self.probs)
+        else:
+            is_scalar = isinstance(logits, Number)
+            self.logits, = broadcast_all(logits)
+        self._param = self.probs if probs is not None else self.logits
+        if is_scalar:
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self._param.size()
+        self._lims = lims
+        super(ContinuousBernoulli, self).__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(ContinuousBernoulli, _instance)
+        new._lims = self._lims
+        batch_shape = torch.Size(batch_shape)
+        if 'probs' in self.__dict__:
+            new.probs = self.probs.expand(batch_shape)
+            new._param = new.probs
+        if 'logits' in self.__dict__:
+            new.logits = self.logits.expand(batch_shape)
+            new._param = new.logits
+        super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
+        new._validate_args = self._validate_args
+        return new
+
+    def _new(self, *args, **kwargs):
+        return self._param.new(*args, **kwargs)
+
+    def _outside_unstable_region(self):
+        return torch.max(torch.le(self.probs, self._lims[0]),
+                         torch.gt(self.probs, self._lims[1]))
+
+    def _cut_probs(self):
+        return torch.where(self._outside_unstable_region(),
+                           self.probs,
+                           self._lims[0] * torch.ones_like(self.probs))
+
+    def _cont_bern_log_norm(self):
+        '''computes the log normalizing constant as a function of the 'probs' parameter'''
+        cut_probs = self._cut_probs()
+        cut_probs_below_half = torch.where(torch.le(cut_probs, 0.5),
+                                           cut_probs,
+                                           torch.zeros_like(cut_probs))
+        cut_probs_above_half = torch.where(torch.ge(cut_probs, 0.5),
+                                           cut_probs,
+                                           torch.ones_like(cut_probs))
+        log_norm = torch.log(torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))) - torch.where(
+            torch.le(cut_probs, 0.5),
+            torch.log1p(-2.0 * cut_probs_below_half),
+            torch.log(2.0 * cut_probs_above_half - 1.0))
+        x = torch.pow(self.probs - 0.5, 2)
+        taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
+        return torch.where(self._outside_unstable_region(), log_norm, taylor)
+
+    @property
+    def mean(self):
+        cut_probs = self._cut_probs()
+        mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (torch.log1p(-cut_probs) - torch.log(cut_probs))
+        x = self.probs - 0.5
+        taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
+        return torch.where(self._outside_unstable_region(), mus, taylor)
+
+    @property
+    def stddev(self):
+        return torch.sqrt(self.variance)
+
+    @property
+    def variance(self):
+        cut_probs = self._cut_probs()
+        vars = cut_probs * (cut_probs - 1.0) / torch.pow(1.0 - 2.0 * cut_probs, 2) + 1.0 / torch.pow(
+            torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
+        x = torch.pow(self.probs - 0.5, 2)
+        taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128. / 945.0 * x) * x
+        return torch.where(self._outside_unstable_region(), vars, taylor)
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return clamp_probs(logits_to_probs(self.logits, is_binary=True))
+
+    @property
+    def param_shape(self):
+        return self._param.size()
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+        with torch.no_grad():
+            return self.icdf(u)
+
+    def rsample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+        return self.icdf(u)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        logits, value = broadcast_all(self.logits, value)
+        return -binary_cross_entropy_with_logits(logits, value, reduction='none') + self._cont_bern_log_norm()
+
+    def cdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        cut_probs = self._cut_probs()
+        cdfs = (torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
+                + cut_probs - 1.0) / (2.0 * cut_probs - 1.0)
+        unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
+        return torch.where(
+            torch.le(value, 0.0),
+            torch.zeros_like(value),
+            torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs))
+
+    def icdf(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        cut_probs = self._cut_probs()
+        return torch.where(
+            self._outside_unstable_region(),
+            (torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
+             - torch.log1p(-cut_probs)) / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
+            value)
+
+    def entropy(self):
+        log_probs0 = torch.log1p(-self.probs)
+        log_probs1 = torch.log(self.probs)
+        return self.mean * (log_probs0 - log_probs1) - self._cont_bern_log_norm() - log_probs0
+
+    @property
+    def _natural_params(self):
+        return (self.logits, )
+
+    def _log_normalizer(self, x):
+        """computes the log normalizing constant as a function of the natural parameter"""
+        out_unst_reg = torch.max(torch.le(x, self._lims[0] - 0.5),
+                                 torch.gt(x, self._lims[1] - 0.5))
+        cut_nat_params = torch.where(out_unst_reg,
+                                     x,
+                                     (self._lims[0] - 0.5) * torch.ones_like(x))
+        log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(torch.abs(cut_nat_params))
+        taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
+        return torch.where(out_unst_reg, log_norm, taylor)
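A small usage sketch of the sampler defined above (my own illustration, not part of the patch): `rsample` pushes uniform noise through `icdf`, so draws are pathwise-differentiable with respect to the parameter. The seed and sample counts are arbitrary:

    import torch
    from torch.distributions import ContinuousBernoulli

    torch.manual_seed(0)
    d = ContinuousBernoulli(probs=torch.tensor(0.3))

    # the Monte Carlo mean should approach the closed-form mean property above
    samples = d.rsample((100000,))
    print(samples.mean().item(), d.mean.item())

    # gradients flow through the inverse-CDF reparameterization
    p = torch.tensor(0.3, requires_grad=True)
    ContinuousBernoulli(probs=p).rsample((1000,)).mean().backward()
    print(p.grad)  # a finite, non-zero pathwise gradient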
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 365715d9310..627896742d2 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -9,6 +9,7 @@ from .bernoulli import Bernoulli
 from .beta import Beta
 from .binomial import Binomial
 from .categorical import Categorical
+from .continuous_bernoulli import ContinuousBernoulli
 from .dirichlet import Dirichlet
 from .distribution import Distribution
 from .exponential import Exponential
@@ -217,6 +218,14 @@ def _kl_categorical_categorical(p, q):
     return t.sum(-1)
 
 
+@register_kl(ContinuousBernoulli, ContinuousBernoulli)
+def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
+    t1 = p.mean * (p.logits - q.logits)
+    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
+    t3 = - q._cont_bern_log_norm() - torch.log1p(-q.probs)
+    return t1 + t2 + t3
+
+
 @register_kl(Dirichlet, Dirichlet)
 def _kl_dirichlet_dirichlet(p, q):
     # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
@@ -445,6 +454,11 @@ def _kl_bernoulli_poisson(p, q):
     return -p.entropy() - (p.probs * q.rate.log() - q.rate)
 
 
+@register_kl(Beta, ContinuousBernoulli)
+def _kl_beta_continuous_bernoulli(p, q):
+    return -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm()
+
+
 @register_kl(Beta, Pareto)
 def _kl_beta_infinity(p, q):
     return _infinite_like(p.concentration1)
@@ -484,8 +498,40 @@ def _kl_beta_uniform(p, q):
     result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
     return result
 
+# Note that the KL between a ContinuousBernoulli and Beta has no closed form
+
+
+@register_kl(ContinuousBernoulli, Pareto)
+def _kl_continuous_bernoulli_infinity(p, q):
+    return _infinite_like(p.probs)
+
+
+@register_kl(ContinuousBernoulli, Exponential)
+def _kl_continuous_bernoulli_exponential(p, q):
+    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean
+
+# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
+# TODO: Add ContinuousBernoulli-Laplace KL Divergence
+
+
+@register_kl(ContinuousBernoulli, Normal)
+def _kl_continuous_bernoulli_normal(p, q):
+    t1 = -p.entropy()
+    t2 = 0.5 * (math.log(2. * math.pi) + torch.square(q.loc / q.scale)) + torch.log(q.scale)
+    t3 = (p.variance + torch.square(p.mean) - 2. * q.loc * p.mean) / (2.0 * torch.square(q.scale))
+    return t1 + t2 + t3
+
+
+@register_kl(ContinuousBernoulli, Uniform)
+def _kl_continuous_bernoulli_uniform(p, q):
+    result = -p.entropy() + (q.high - q.low).log()
+    return torch.where(torch.max(torch.ge(q.low, p.support.lower_bound),
+                                 torch.le(q.high, p.support.upper_bound)),
+                       torch.ones_like(result) * inf, result)
+
 
 @register_kl(Exponential, Beta)
+@register_kl(Exponential, ContinuousBernoulli)
 @register_kl(Exponential, Pareto)
 @register_kl(Exponential, Uniform)
 def _kl_exponential_infinity(p, q):
@@ -523,6 +569,7 @@ def _kl_exponential_normal(p, q):
 
 
 @register_kl(Gamma, Beta)
+@register_kl(Gamma, ContinuousBernoulli)
 @register_kl(Gamma, Pareto)
 @register_kl(Gamma, Uniform)
 def _kl_gamma_infinity(p, q):
@@ -558,6 +605,7 @@ def _kl_gamma_normal(p, q):
 
 
 @register_kl(Gumbel, Beta)
+@register_kl(Gumbel, ContinuousBernoulli)
 @register_kl(Gumbel, Exponential)
 @register_kl(Gumbel, Gamma)
 @register_kl(Gumbel, Pareto)
@@ -578,6 +626,7 @@ def _kl_gumbel_normal(p, q):
 
 
 @register_kl(Laplace, Beta)
+@register_kl(Laplace, ContinuousBernoulli)
 @register_kl(Laplace, Exponential)
 @register_kl(Laplace, Gamma)
 @register_kl(Laplace, Pareto)
@@ -598,6 +647,7 @@ def _kl_laplace_normal(p, q):
 
 
 @register_kl(Normal, Beta)
+@register_kl(Normal, ContinuousBernoulli)
 @register_kl(Normal, Exponential)
 @register_kl(Normal, Gamma)
 @register_kl(Normal, Pareto)
@@ -620,6 +670,7 @@ def _kl_normal_gumbel(p, q):
 
 
 @register_kl(Pareto, Beta)
+@register_kl(Pareto, ContinuousBernoulli)
 @register_kl(Pareto, Uniform)
 def _kl_pareto_infinity(p, q):
     return _infinite_like(p.scale)
@@ -681,6 +732,14 @@ def _kl_uniform_beta(p, q):
     return result
 
 
+@register_kl(Uniform, ContinuousBernoulli)
+def _kl_uniform_continuous_bernoulli(p, q):
+    result = -p.entropy() - p.mean * q.logits - torch.log1p(-q.probs) - q._cont_bern_log_norm()
+    return torch.where(torch.max(torch.ge(p.high, q.support.upper_bound),
+                                 torch.le(p.low, q.support.lower_bound)),
+                       torch.ones_like(result) * inf, result)
+
+
 @register_kl(Uniform, Exponential)
 def _kl_uniform_exponetial(p, q):
     result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
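To close, a quick consistency check (again my own sketch, not part of the patch) that the closed-form ContinuousBernoulli-to-ContinuousBernoulli KL registered above agrees with a Monte Carlo estimate; the seed, parameter values, and sample count are arbitrary:

    import torch
    from torch.distributions import ContinuousBernoulli, kl_divergence

    torch.manual_seed(0)
    p = ContinuousBernoulli(probs=torch.tensor(0.2))
    q = ContinuousBernoulli(probs=torch.tensor(0.6))

    closed_form = kl_divergence(p, q)

    # KL(p || q) = E_p[log p(x) - log q(x)], estimated from samples of p
    x = p.sample((200000,))
    monte_carlo = (p.log_prob(x) - q.log_prob(x)).mean()
    print(closed_form.item(), monte_carlo.item())  # should agree to within MC error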