Fix Dirichlet.log_prob() when x=0 and alpha=1 (#103605)

`Dirichlet.log_prob()` incorrectly returns NaN in the case where $x_i=0$ and $\alpha_i=1$.  The Dirichlet PDF is given by:
$$\frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$
So this corresponds to the case where one of the terms has the form $0^0=1$. The logarithm of such a term should be 0, but you get NaN if you try to calculate it as `0 * log(0)`.

This PR implements the same algorithm that `scipy.stats.dirichlet` uses to avoid this behavior, namely `xlogy(alpha - 1, x)` instead of `(alpha - 1) * log(x)`.  It also adds a test case comparing the pytorch and scipy implementations for this specific case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103605
Approved by: https://github.com/albanD
This commit is contained in:
Kale Kundert 2023-06-15 16:16:47 +00:00 committed by PyTorch MergeBot
parent 2f893d04c8
commit e75f7994e1
2 changed files with 15 additions and 1 deletions

View file

@ -2773,6 +2773,20 @@ class TestDistributions(DistributionsTestCase):
expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_dirichlet_log_prob_zero(self):
    # Regression check for the x=0, α=1 corner case: the density term is
    # x**(α-1) = 0**0 = 1, so its log contribution must be 0. A naive
    # (α-1)*log(x) computation yields NaN here; compare against scipy's
    # implementation, which handles this via xlogy.
    concentration = torch.tensor([1, 2])
    sample = torch.tensor([0, 1])
    distribution = Dirichlet(concentration)
    expected = scipy.stats.dirichlet.logpdf(sample.numpy(), concentration.numpy())
    self.assertEqual(distribution.log_prob(sample), expected, atol=1e-3, rtol=0)
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_dirichlet_sample(self):
set_rng_seed(0) # see Note [Randomized statistical tests]

View file

@ -69,7 +69,7 @@ class Dirichlet(ExponentialFamily):
def log_prob(self, value):
    """Return the log-density of the Dirichlet distribution at ``value``.

    Uses ``torch.xlogy(concentration - 1, value)`` rather than
    ``(concentration - 1) * torch.log(value)`` so that components with
    ``value == 0`` and ``concentration == 1`` contribute 0 (since
    ``0**0 == 1``) instead of NaN — matching ``scipy.stats.dirichlet``.
    """
    if self._validate_args:
        self._validate_sample(value)
    return (
        torch.xlogy(self.concentration - 1.0, value).sum(-1)
        + torch.lgamma(self.concentration.sum(-1))
        - torch.lgamma(self.concentration).sum(-1)
    )