From 313e77fc06aded10ea07f9807f8736ec033cb574 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 3 Dec 2020 02:40:23 -0800 Subject: [PATCH] Add broadcast_shapes() function and use it in MultivariateNormal (#43935) Summary: Fixes https://github.com/pytorch/pytorch/issues/43837 This adds a `torch.broadcast_shapes()` function similar to Pyro's [broadcast_shape()](https://github.com/pyro-ppl/pyro/blob/7c2c22c10dffda8a33ffbd593cc8d58819959e40/pyro/distributions/util.py#L151) and JAX's [lax.broadcast_shapes()](https://jax.readthedocs.io/en/test-docs/_modules/jax/lax/lax.html). This helper is useful e.g. in multivariate distributions that are parameterized by multiple tensors and we want to `torch.broadcast_tensors()` but the parameter tensors have different "event shape" (e.g. mean vectors and covariance matrices). This helper is already heavily used in Pyro's distribution codebase, and we would like to start using it in `torch.distributions`. - [x] refactor `MultivariateNormal`'s expansion logic to use `torch.broadcast_shapes()` - [x] add unit tests for `torch.broadcast_shapes()` - [x] add docs cc neerajprad Pull Request resolved: https://github.com/pytorch/pytorch/pull/43935 Reviewed By: bdhirsh Differential Revision: D25275213 Pulled By: neerajprad fbshipit-source-id: 1011fdd597d0a7a4ef744ebc359bbb3c3be2aadc --- docs/source/torch.rst | 1 + test/test_view_ops.py | 17 +++++++++++ torch/distributions/multivariate_normal.py | 14 +++++---- torch/functional.py | 34 ++++++++++++++++++++++ torch/overrides.py | 1 + 5 files changed, 61 insertions(+), 6 deletions(-) diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 4399e63c3b0..b16f14de9bf 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -447,6 +447,7 @@ Other Operations bincount block_diag broadcast_tensors + broadcast_shapes bucketize cartesian_prod cdist diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 6722a55588e..d4e59a3dbf2 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -1021,6 +1021,23 @@ class TestOldViewOps(TestCase): self.assertTrue(y1.size() == expected_size) self.assertTrue(y2.size() == expected_size) + + @onlyCPU + def test_broadcast_shapes(self, device): + examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)] + for s0 in examples: + x0 = torch.randn(s0) + expected = torch.broadcast_tensors(x0)[0].shape + actual = torch.broadcast_shapes(s0) + self.assertEqual(expected, actual) + + for s1 in examples: + x1 = torch.randn(s1) + expected = torch.broadcast_tensors(x0, x1)[0].shape + actual = torch.broadcast_shapes(s0, s1) + self.assertEqual(expected, actual) + + def test_view(self, device): tensor = torch.rand(15, device=device) template = torch.rand(3, 5, device=device) diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index de997f49a94..4845d4742df 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -122,25 +122,27 @@ class MultivariateNormal(Distribution): if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1: raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.") - loc_ = loc.unsqueeze(-1) # temporarily add dim on right if scale_tril is not None: if scale_tril.dim() < 2: raise ValueError("scale_tril matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_) + batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) + self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: raise ValueError("covariance_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_) + batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1]) + self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) else: if precision_matrix.dim() < 2: raise ValueError("precision_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_) - self.loc = loc_[..., 0] # drop rightmost dim + batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1]) + self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) + self.loc = loc.expand(batch_shape + (-1,)) - batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] + event_shape = self.loc.shape[-1:] super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args) if scale_tril is not None: diff --git a/torch/functional.py b/torch/functional.py index 29af0b662cc..62076a9dc29 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -19,6 +19,7 @@ __all__ = [ 'atleast_2d', 'atleast_3d', 'align_tensors', + 'broadcast_shapes', 'broadcast_tensors', 'cartesian_prod', 'block_diag', @@ -72,6 +73,39 @@ def broadcast_tensors(*tensors): return _VF.broadcast_tensors(tensors) # type: ignore +def broadcast_shapes(*shapes): + r"""broadcast_shapes(*shapes) -> Size + + Similar to :func:`broadcast_tensors` but for shapes. + + This is equivalent to + ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape`` + but avoids the need create to intermediate tensors. This is useful for + broadcasting tensors of common batch shape but different rightmost shape, + e.g. to broadcast mean vectors with covariance matrices. + + Example:: + + >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1)) + torch.Size([1, 3, 2]) + + Args: + \*shapes (torch.Size): Shapes of tensors. + + Returns: + shape (torch.Size): A shape compatible with all input shapes. + + Raises: + RuntimeError: If shapes are incompatible. + """ + # TODO Movie this to C++ once the jit has better support for torch.Size. + with torch.no_grad(): + scalar = torch.zeros((), device="cpu") + tensors = [scalar.expand(shape) for shape in shapes] + tensors = broadcast_tensors(*tensors) + return tensors[0].shape + + def split(tensor, split_size_or_sections, dim=0): r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. diff --git a/torch/overrides.py b/torch/overrides.py index 36ae037ed55..f6a49376ab5 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -119,6 +119,7 @@ def get_ignored_functions() -> Set[Callable]: torch.as_strided, torch.bartlett_window, torch.blackman_window, + torch.broadcast_shapes, torch.can_cast, torch.cudnn_affine_grid_generator, torch.cudnn_batch_norm,