mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add broadcast_shapes() function and use it in MultivariateNormal (#43935)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/43837
This adds a `torch.broadcast_shapes()` function similar to Pyro's [broadcast_shape()](7c2c22c10d/pyro/distributions/util.py (L151)) and JAX's [lax.broadcast_shapes()](https://jax.readthedocs.io/en/test-docs/_modules/jax/lax/lax.html). This helper is useful e.g. in multivariate distributions that are parameterized by multiple tensors and we want to `torch.broadcast_tensors()` but the parameter tensors have different "event shape" (e.g. mean vectors and covariance matrices). This helper is already heavily used in Pyro's distribution codebase, and we would like to start using it in `torch.distributions`.
- [x] refactor `MultivariateNormal`'s expansion logic to use `torch.broadcast_shapes()`
- [x] add unit tests for `torch.broadcast_shapes()`
- [x] add docs
cc neerajprad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43935
Reviewed By: bdhirsh
Differential Revision: D25275213
Pulled By: neerajprad
fbshipit-source-id: 1011fdd597d0a7a4ef744ebc359bbb3c3be2aadc
This commit is contained in:
parent
c7746adbc6
commit
313e77fc06
5 changed files with 61 additions and 6 deletions
|
|
@ -447,6 +447,7 @@ Other Operations
|
|||
bincount
|
||||
block_diag
|
||||
broadcast_shapes
|
||||
broadcast_tensors
|
||||
bucketize
|
||||
cartesian_prod
|
||||
cdist
|
||||
|
|
|
|||
|
|
@ -1021,6 +1021,23 @@ class TestOldViewOps(TestCase):
|
|||
self.assertTrue(y1.size() == expected_size)
|
||||
self.assertTrue(y2.size() == expected_size)
|
||||
|
||||
|
||||
@onlyCPU
def test_broadcast_shapes(self, device):
    """Check torch.broadcast_shapes agrees with torch.broadcast_tensors on example shapes."""
    shapes = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
    for lhs in shapes:
        t_lhs = torch.randn(lhs)
        # Single-shape case: must match broadcasting a lone tensor.
        self.assertEqual(torch.broadcast_tensors(t_lhs)[0].shape,
                         torch.broadcast_shapes(lhs))
        # Pairwise case: every ordered combination of two example shapes.
        for rhs in shapes:
            t_rhs = torch.randn(rhs)
            self.assertEqual(torch.broadcast_tensors(t_lhs, t_rhs)[0].shape,
                             torch.broadcast_shapes(lhs, rhs))
|
||||
|
||||
|
||||
def test_view(self, device):
|
||||
tensor = torch.rand(15, device=device)
|
||||
template = torch.rand(3, 5, device=device)
|
||||
|
|
|
|||
|
|
@ -122,25 +122,27 @@ class MultivariateNormal(Distribution):
|
|||
if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
|
||||
raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")
|
||||
|
||||
loc_ = loc.unsqueeze(-1) # temporarily add dim on right
|
||||
if scale_tril is not None:
|
||||
if scale_tril.dim() < 2:
|
||||
raise ValueError("scale_tril matrix must be at least two-dimensional, "
|
||||
"with optional leading batch dimensions")
|
||||
self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
|
||||
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
|
||||
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
|
||||
elif covariance_matrix is not None:
|
||||
if covariance_matrix.dim() < 2:
|
||||
raise ValueError("covariance_matrix must be at least two-dimensional, "
|
||||
"with optional leading batch dimensions")
|
||||
self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
|
||||
batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1])
|
||||
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
|
||||
else:
|
||||
if precision_matrix.dim() < 2:
|
||||
raise ValueError("precision_matrix must be at least two-dimensional, "
|
||||
"with optional leading batch dimensions")
|
||||
self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
|
||||
self.loc = loc_[..., 0] # drop rightmost dim
|
||||
batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1])
|
||||
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
|
||||
self.loc = loc.expand(batch_shape + (-1,))
|
||||
|
||||
batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
|
||||
event_shape = self.loc.shape[-1:]
|
||||
super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
|
||||
|
||||
if scale_tril is not None:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ __all__ = [
|
|||
'atleast_2d',
|
||||
'atleast_3d',
|
||||
'align_tensors',
|
||||
'broadcast_shapes',
|
||||
'broadcast_tensors',
|
||||
'cartesian_prod',
|
||||
'block_diag',
|
||||
|
|
@ -72,6 +73,39 @@ def broadcast_tensors(*tensors):
|
|||
return _VF.broadcast_tensors(tensors) # type: ignore
|
||||
|
||||
|
||||
def broadcast_shapes(*shapes):
|
||||
r"""broadcast_shapes(*shapes) -> Size
|
||||
|
||||
Similar to :func:`broadcast_tensors` but for shapes.
|
||||
|
||||
This is equivalent to
|
||||
``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
|
||||
but avoids the need create to intermediate tensors. This is useful for
|
||||
broadcasting tensors of common batch shape but different rightmost shape,
|
||||
e.g. to broadcast mean vectors with covariance matrices.
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
|
||||
torch.Size([1, 3, 2])
|
||||
|
||||
Args:
|
||||
\*shapes (torch.Size): Shapes of tensors.
|
||||
|
||||
Returns:
|
||||
shape (torch.Size): A shape compatible with all input shapes.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If shapes are incompatible.
|
||||
"""
|
||||
# TODO Movie this to C++ once the jit has better support for torch.Size.
|
||||
with torch.no_grad():
|
||||
scalar = torch.zeros((), device="cpu")
|
||||
tensors = [scalar.expand(shape) for shape in shapes]
|
||||
tensors = broadcast_tensors(*tensors)
|
||||
return tensors[0].shape
|
||||
|
||||
|
||||
def split(tensor, split_size_or_sections, dim=0):
|
||||
r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
torch.as_strided,
|
||||
torch.bartlett_window,
|
||||
torch.blackman_window,
|
||||
torch.broadcast_shapes,
|
||||
torch.can_cast,
|
||||
torch.cudnn_affine_grid_generator,
|
||||
torch.cudnn_batch_norm,
|
||||
|
|
|
|||
Loading…
Reference in a new issue