From e75f7994e1fb60c5055477360bbe41dd67de8aff Mon Sep 17 00:00:00 2001
From: Kale Kundert
Date: Thu, 15 Jun 2023 16:16:47 +0000
Subject: [PATCH] Fix `Dirichlet.log_prob()` when x=0 and alpha=1 (#103605)

`Dirichlet.log_prob()` incorrectly returns NaN in the case where $x_i=0$ and $\alpha_i=1$. The Dirichlet PDF is given by:

$$\frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$

So this corresponds to the case where one of the terms has the form $0^0=1$. The logarithm of such a term should be 0, but you get NaN if you try to calculate it as `0 * log(0)`.

This PR implements the same algorithm that `scipy.stats.dirichlet` uses to avoid this behavior, namely `xlogy(alpha - 1, x)` instead of `(alpha - 1) * log(x)`. It also adds a test case comparing the pytorch and scipy implementations for this specific case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103605
Approved by: https://github.com/albanD
---
 test/distributions/test_distributions.py | 14 ++++++++++++++
 torch/distributions/dirichlet.py         |  2 +-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index 7fbf4e88e0e..69591d31c5e 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -2773,6 +2773,20 @@ class TestDistributions(DistributionsTestCase):
             expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
             self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)
 
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_dirichlet_log_prob_zero(self):
+        # Specifically test the special case where x=0 and α=1. The PDF is
+        # proportional to x**(α-1), which in this case works out to 0**0=1.
+        # The log PDF of this term should therefore be 0. However, it's easy
+        # to accidentally introduce NaNs by calculating log(x) without regard
+        # for the value of α-1.
+        alpha = torch.tensor([1, 2])
+        dist = Dirichlet(alpha)
+        x = torch.tensor([0, 1])
+        actual_log_prob = dist.log_prob(x)
+        expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
+        self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)
+
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     def test_dirichlet_sample(self):
         set_rng_seed(0)  # see Note [Randomized statistical tests]
diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py
index 1612e37f42e..0a38ff50c26 100644
--- a/torch/distributions/dirichlet.py
+++ b/torch/distributions/dirichlet.py
@@ -69,7 +69,7 @@ class Dirichlet(ExponentialFamily):
     def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
-        return ((torch.log(value) * (self.concentration - 1.0)).sum(-1) +
+        return (torch.xlogy(self.concentration - 1.0, value).sum(-1) +
                 torch.lgamma(self.concentration.sum(-1)) -
                 torch.lgamma(self.concentration).sum(-1))
 