mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
add type annotation to distributions.kl_divergence (#78432)
Fixes #78431 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78432 Approved by: https://github.com/fritzo, https://github.com/ejguan
This commit is contained in:
parent
77b6885a22
commit
dd620c4575
1 changed files with 1 additions and 1 deletions
|
|
@ -141,7 +141,7 @@ def _batch_trace_XXT(bmat):
|
|||
return flat_trace.reshape(bmat.shape[:-2])
|
||||
|
||||
|
||||
def kl_divergence(p, q):
|
||||
def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
|
||||
r"""
|
||||
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue