mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix Dirichlet.log_prob() when x=0 and alpha=1 (#103605)
`Dirichlet.log_prob()` incorrectly returns NaN in the case where $x_i=0$ and $\alpha_i=1$. The Dirichlet PDF is given by:
$$\frac{1}{B(\alpha)} \prod_{i=1}^{K} x_i^{\alpha_i - 1}$$
So this corresponds to the case where one of the terms has the form $0^0=1$. The logarithm of such a term should be 0, but you get NaN if you try to calculate it as `0 * log(0)`.
This PR implements the same algorithm that `scipy.stats.dirichlet` uses to avoid this behavior, namely `xlogy(alpha - 1, x)` instead of `(alpha - 1) * log(x)`. It also adds a test case comparing the pytorch and scipy implementations for this specific case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103605
Approved by: https://github.com/albanD
This commit is contained in:
parent
2f893d04c8
commit
e75f7994e1
2 changed files with 15 additions and 1 deletions
|
|
@ -2773,6 +2773,20 @@ class TestDistributions(DistributionsTestCase):
|
|||
expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
|
||||
self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_dirichlet_log_prob_zero(self):
    # Regression check for the boundary case x=0 with α=1.  The density is
    # proportional to x**(α-1), i.e. 0**0 = 1 here, whose log must be 0.
    # A naive (α-1)*log(x) formulation produces NaN for this input, so we
    # compare against scipy's reference implementation.
    alpha = torch.tensor([1, 2])
    x = torch.tensor([0, 1])
    actual_log_prob = Dirichlet(alpha).log_prob(x)
    expected_log_prob = scipy.stats.dirichlet.logpdf(x.numpy(), alpha.numpy())
    self.assertEqual(actual_log_prob, expected_log_prob, atol=1e-3, rtol=0)
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
||||
def test_dirichlet_sample(self):
|
||||
set_rng_seed(0) # see Note [Randomized statistical tests]
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class Dirichlet(ExponentialFamily):
|
|||
def log_prob(self, value):
    """Return the log of the Dirichlet density evaluated at ``value``.

    Args:
        value (Tensor): point(s) on the simplex at which to evaluate the
            log density; the last dimension indexes the K components.

    Returns:
        Tensor: log probability density, with the last dimension reduced.
    """
    if self._validate_args:
        self._validate_sample(value)
    # Use xlogy(a-1, x) rather than (a-1) * log(x): xlogy(0, 0) == 0, so the
    # x_i = 0, alpha_i = 1 boundary case contributes log(0**0) = 0 instead of
    # the NaN that 0 * log(0) would produce (matches scipy.stats.dirichlet).
    return (
        torch.xlogy(self.concentration - 1.0, value).sum(-1)
        + torch.lgamma(self.concentration.sum(-1))
        - torch.lgamma(self.concentration).sum(-1)
    )
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue