Add broadcast_shapes() function and use it in MultivariateNormal (#43935)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/43837

This adds a `torch.broadcast_shapes()` function similar to Pyro's [broadcast_shape()](7c2c22c10d/pyro/distributions/util.py (L151)) and JAX's [lax.broadcast_shapes()](https://jax.readthedocs.io/en/test-docs/_modules/jax/lax/lax.html). This helper is useful e.g. in multivariate distributions that are parameterized by multiple tensors and we want to `torch.broadcast_tensors()` but the parameter tensors have different "event shape" (e.g. mean vectors and covariance matrices). This helper is already heavily used in Pyro's distribution codebase, and we would like to start using it in `torch.distributions`.

- [x] refactor `MultivariateNormal`'s expansion logic to use `torch.broadcast_shapes()`
- [x] add unit tests for `torch.broadcast_shapes()`
- [x] add docs

cc neerajprad

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43935

Reviewed By: bdhirsh

Differential Revision: D25275213

Pulled By: neerajprad

fbshipit-source-id: 1011fdd597d0a7a4ef744ebc359bbb3c3be2aadc
This commit is contained in:
Fritz Obermeyer 2020-12-03 02:40:23 -08:00 committed by Facebook GitHub Bot
parent c7746adbc6
commit 313e77fc06
5 changed files with 61 additions and 6 deletions

View file

@ -447,6 +447,7 @@ Other Operations
bincount
block_diag
broadcast_tensors
broadcast_shapes
bucketize
cartesian_prod
cdist

View file

@ -1021,6 +1021,23 @@ class TestOldViewOps(TestCase):
self.assertTrue(y1.size() == expected_size)
self.assertTrue(y2.size() == expected_size)
@onlyCPU
def test_broadcast_shapes(self, device):
examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
for s0 in examples:
x0 = torch.randn(s0)
expected = torch.broadcast_tensors(x0)[0].shape
actual = torch.broadcast_shapes(s0)
self.assertEqual(expected, actual)
for s1 in examples:
x1 = torch.randn(s1)
expected = torch.broadcast_tensors(x0, x1)[0].shape
actual = torch.broadcast_shapes(s0, s1)
self.assertEqual(expected, actual)
def test_view(self, device):
tensor = torch.rand(15, device=device)
template = torch.rand(3, 5, device=device)

View file

@ -122,25 +122,27 @@ class MultivariateNormal(Distribution):
if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")
loc_ = loc.unsqueeze(-1) # temporarily add dim on right
if scale_tril is not None:
if scale_tril.dim() < 2:
raise ValueError("scale_tril matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:
raise ValueError("covariance_matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1])
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
if precision_matrix.dim() < 2:
raise ValueError("precision_matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
self.loc = loc_[..., 0] # drop rightmost dim
batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1])
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,))
batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
event_shape = self.loc.shape[-1:]
super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
if scale_tril is not None:

View file

@ -19,6 +19,7 @@ __all__ = [
'atleast_2d',
'atleast_3d',
'align_tensors',
'broadcast_shapes',
'broadcast_tensors',
'cartesian_prod',
'block_diag',
@ -72,6 +73,39 @@ def broadcast_tensors(*tensors):
return _VF.broadcast_tensors(tensors) # type: ignore
def broadcast_shapes(*shapes):
r"""broadcast_shapes(*shapes) -> Size
Similar to :func:`broadcast_tensors` but for shapes.
This is equivalent to
``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
but avoids the need create to intermediate tensors. This is useful for
broadcasting tensors of common batch shape but different rightmost shape,
e.g. to broadcast mean vectors with covariance matrices.
Example::
>>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
torch.Size([1, 3, 2])
Args:
\*shapes (torch.Size): Shapes of tensors.
Returns:
shape (torch.Size): A shape compatible with all input shapes.
Raises:
RuntimeError: If shapes are incompatible.
"""
# TODO Movie this to C++ once the jit has better support for torch.Size.
with torch.no_grad():
scalar = torch.zeros((), device="cpu")
tensors = [scalar.expand(shape) for shape in shapes]
tensors = broadcast_tensors(*tensors)
return tensors[0].shape
def split(tensor, split_size_or_sections, dim=0):
r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

View file

@ -119,6 +119,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.as_strided,
torch.bartlett_window,
torch.blackman_window,
torch.broadcast_shapes,
torch.can_cast,
torch.cudnn_affine_grid_generator,
torch.cudnn_batch_norm,