add type annotation to distributions.kl_divergence (#78432)

Fixes #78431

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78432
Approved by: https://github.com/fritzo, https://github.com/ejguan
This commit is contained in:
Tongzhou Wang 2022-06-10 13:39:20 +00:00 committed by PyTorch MergeBot
parent 77b6885a22
commit dd620c4575

View file

@@ -141,7 +141,7 @@ def _batch_trace_XXT(bmat):
     return flat_trace.reshape(bmat.shape[:-2])
-def kl_divergence(p, q):
+def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
     r"""
     Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.