diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 4399e63c3b0..b16f14de9bf 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -447,6 +447,7 @@ Other Operations bincount block_diag + broadcast_shapes broadcast_tensors bucketize cartesian_prod cdist diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 6722a55588e..d4e59a3dbf2 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -1021,6 +1021,23 @@ class TestOldViewOps(TestCase): self.assertTrue(y1.size() == expected_size) self.assertTrue(y2.size() == expected_size) + + @onlyCPU + def test_broadcast_shapes(self, device): + examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)] + for s0 in examples: + x0 = torch.randn(s0) + expected = torch.broadcast_tensors(x0)[0].shape + actual = torch.broadcast_shapes(s0) + self.assertEqual(expected, actual) + + for s1 in examples: + x1 = torch.randn(s1) + expected = torch.broadcast_tensors(x0, x1)[0].shape + actual = torch.broadcast_shapes(s0, s1) + self.assertEqual(expected, actual) + + def test_view(self, device): tensor = torch.rand(15, device=device) template = torch.rand(3, 5, device=device) diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index de997f49a94..4845d4742df 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -122,25 +122,27 @@ class MultivariateNormal(Distribution): if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1: raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.") - loc_ = loc.unsqueeze(-1) # temporarily add dim on right if scale_tril is not None: if scale_tril.dim() < 2: raise ValueError("scale_tril matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_) + batch_shape = 
torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) + self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: raise ValueError("covariance_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_) + batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1]) + self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) else: if precision_matrix.dim() < 2: raise ValueError("precision_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_) - self.loc = loc_[..., 0] # drop rightmost dim + batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1]) + self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) + self.loc = loc.expand(batch_shape + (-1,)) - batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] + event_shape = self.loc.shape[-1:] super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args) if scale_tril is not None: diff --git a/torch/functional.py b/torch/functional.py index 29af0b662cc..62076a9dc29 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -19,6 +19,7 @@ __all__ = [ 'atleast_2d', 'atleast_3d', 'align_tensors', + 'broadcast_shapes', 'broadcast_tensors', 'cartesian_prod', 'block_diag', @@ -72,6 +73,39 @@ def broadcast_tensors(*tensors): return _VF.broadcast_tensors(tensors) # type: ignore +def broadcast_shapes(*shapes): + r"""broadcast_shapes(*shapes) -> Size + + Similar to :func:`broadcast_tensors` but for shapes. + + This is equivalent to + ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape`` + but avoids the need to create intermediate tensors. 
This is useful for + broadcasting tensors of common batch shape but different rightmost shape, + e.g. to broadcast mean vectors with covariance matrices. + + Example:: + + >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1)) + torch.Size([1, 3, 2]) + + Args: + \*shapes (torch.Size): Shapes of tensors. + + Returns: + shape (torch.Size): A shape compatible with all input shapes. + + Raises: + RuntimeError: If shapes are incompatible. + """ + # TODO Move this to C++ once the jit has better support for torch.Size. + with torch.no_grad(): + scalar = torch.zeros((), device="cpu") + tensors = [scalar.expand(shape) for shape in shapes] + tensors = broadcast_tensors(*tensors) + return tensors[0].shape + + def split(tensor, split_size_or_sections, dim=0): r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. diff --git a/torch/overrides.py b/torch/overrides.py index 36ae037ed55..f6a49376ab5 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -119,6 +119,7 @@ def get_ignored_functions() -> Set[Callable]: torch.as_strided, torch.bartlett_window, torch.blackman_window, + torch.broadcast_shapes, torch.can_cast, torch.cudnn_affine_grid_generator, torch.cudnn_batch_norm,