diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index ea85658fc01..8a39dbe5f67 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -141,7 +141,7 @@ def _batch_trace_XXT(bmat): return flat_trace.reshape(bmat.shape[:-2]) -def kl_divergence(p, q): +def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor: r""" Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.