# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
from typing import cast

import torch
import torch.distributed as dist
from torch import rand, randn, Tensor
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor._ops._view_ops import (
    Broadcast,
    dim_maps,
    Flatten,
    InputDim,
    Repeat,
    Singleton,
    Split,
    view_groups,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.utils import _pytree as pytree


class TestViewOps(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 6
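
    # `view_groups(from_shape, to_shape)` decomposes a reshape into one rule
    # per output dim: InputDim (an input dim passes through), Flatten
    # (adjacent input dims merge into one), Split (a dim or flattened group
    # is split apart, indexed by split_id), and Singleton (a new size-1 dim).
    # The cases below exercise each rule and their compositions.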
    def test_view_groups(self):
        self.assertEqual(
            view_groups([2, 3], [3, 2]),
            (
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 5], [12, 5]),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [12, 70]),
            (
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    0,
                ),
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    1,
                ),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]),
            (
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0),
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 8, 3], [12, 4, 2, 3]),
            (
                Flatten((InputDim(0), InputDim(1))),
                Split(InputDim(2), (4, 2), 0),
                Split(InputDim(2), (4, 2), 1),
                InputDim(3),
            ),
        )
        self.assertEqual(
            view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]),
            (
                Singleton(),
                InputDim(0),
                Split(InputDim(1), (2, 4, 3), 0),
                Split(InputDim(1), (2, 4, 3), 1),
                Singleton(),
                Split(InputDim(1), (2, 4, 3), 2),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]),
            (
                Flatten((InputDim(2), InputDim(3))),
                InputDim(4),
                InputDim(5),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]),
            (
                Split(InputDim(2), (3, 4), 0),
                Split(InputDim(2), (3, 4), 1),
                InputDim(3),
                Flatten((InputDim(6), InputDim(7))),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4], [2, -1, 4]),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
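
    # Runs `op` under every sharding that the op's dim_map permits: dims
    # consumed by Repeat, or appearing as a non-leading input of a Flatten,
    # cannot be sharded without redistribution and are excluded. The op must
    # then execute with zero collectives and match the eager result.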
    def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
        dim_map = dim_maps[op]
        rules = dim_map(*args, **kwargs)
        outputs = op(*args, **kwargs)
        flat_args = pytree.arg_tree_leaves(*args)
        in_shape = flat_args[0].shape

        no_shard_dims = set()
        for rule in rules:
            if isinstance(rule, Repeat):
                if isinstance(rule.input_dim, InputDim):
                    no_shard_dims.add(rule.input_dim.input_dim)
            elif isinstance(rule, Flatten):
                for dim in rule.input_dims[1:]:
                    if isinstance(dim, InputDim):
                        no_shard_dims.add(dim.input_dim)
            elif isinstance(rule, Split):
                if isinstance(rule.input_dim, Flatten):
                    for dim in rule.input_dim.input_dims[1:]:
                        if isinstance(dim, InputDim):
                            no_shard_dims.add(dim.input_dim)

        # unbind removes the target dim entirely, so it cannot be sharded.
        if op == torch.unbind:
            no_shard_dims.add(kwargs.get("dim", 0))

        sharding_choices = cast(list[Placement], [Replicate()]) + [
            Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
        ]
        all_sharding_choices = itertools.product(
            *(device_mesh.ndim * [sharding_choices])
        )
        for in_shard in all_sharding_choices:
            in_dt = distribute_tensor(args[0], device_mesh, in_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                out_dt = op(in_dt, *args[1:], **kwargs)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )

            full_out = out_dt.full_tensor()
            if dist.get_rank() == 0:
                self.assertEqual(outputs, full_out)
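
    # Checks that the op's dim_map yields the expected rules, then validates
    # sharded execution end to end via call_dt_test.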
    def dimmap_test(self, op, args, expected_rule_output):
        rules = dim_maps[op](*args)
        self.assertEqual(rules, expected_rule_output)
        self.call_dt_test(op, args, {}, self.device_mesh)
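
    # With world_size 6, torch.arange(6).view(-1, 2) yields a 3x2 device
    # mesh, so every case below is exercised over placements on a 2-D mesh.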
    @with_comms
    def test_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
        self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),))
        self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),))
        self.dimmap_test(torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1)))

        self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton()))
        self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0)))
        self.dimmap_test(torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1)))
        self.dimmap_test(
            torch.atleast_2d,
            (randn(24, 36, 48),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.atleast_3d,
            (randn(()),),
            (Singleton(), Singleton(), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24),),
            (Singleton(), InputDim(0), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36),),
            (InputDim(0), InputDim(1), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42, 24),),
            (InputDim(0), InputDim(1), InputDim(2), InputDim(3)),
        )

        with self.assertRaises(AssertionError):
            dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4))

        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (1, 24, 36)),
            (Singleton(), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (42, 24, 36)),
            (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (12, 24, 24, 36)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 24),
                InputDim(2),
            ),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (-1, 36)),
            (InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (-1, 1, 36)),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (randn(36, 1, 24), (12, 36, 42, 24)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
            ),
        )

        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), 36, 24, 42, -1, 24),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )
        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            torch.flatten,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),))

        self.dimmap_test(
            torch.movedim,
            (randn(12, 24, 48, 96), 1, 2),
            (InputDim(0), InputDim(2), InputDim(1), InputDim(3)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(6, 12, 24), 1, 0),
            (InputDim(1), InputDim(0), InputDim(2)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12, 6), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12), (1, 0), (0, 1)),
            (InputDim(1), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (-3, -2)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (2, 0, 1)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (-1, -3, -2)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )

        self.dimmap_test(
            torch.ravel,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),))

        self.dimmap_test(
            Tensor.repeat,
            (randn(24, 36), 1, 2, 1, 1, 2),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )

        self.dimmap_test(
            torch.reshape,
            (randn(6, 12, 24), (72, 24)),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(
            torch.tile,
            (randn(24, 36), (1, 2, 1, 1, 2)),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )
        self.dimmap_test(
            torch.tile,
            (randn(42, 24, 36), (1, 3)),
            (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)),
        )

        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), 2, 0),
            (InputDim(2), InputDim(1), InputDim(0), InputDim(3)),
        )
        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), -1, 0),
            (InputDim(3), InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.unsqueeze,
            (randn(42, 24, 36), 1),
            (InputDim(0), Singleton(), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(6, 12, 24), 72, 24),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )
        self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),))
        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 24), -1),
            (Flatten((InputDim(2), InputDim(3))),),
        )
        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 1, 24, 1), -1),
            (Flatten((InputDim(2), InputDim(input_dim=3), InputDim(4))),),
        )
        self.dimmap_test(
            Tensor.view,
            (randn(48, 35, 26), (24, 4, 35, 13)),
            (
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=0,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=1,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=2,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=3,
                ),
            ),
        )

        # TODO: Currently functional collectives on complex numbers are not
        # fully supported, so we have a standalone test for view_as_complex
        # and view_as_real combined. Once complex numbers are supported, we
        # can add the following to the dim_map test.
        #
        # self.dimmap_test(
        #     torch.view_as_complex,
        #     (randn(24, 13, 2),),
        #     (
        #         InputDim(0),
        #         Flatten((InputDim(1), InputDim(2))),
        #     ),
        # )
        # self.dimmap_test(
        #     torch.view_as_real,
        #     (torch.randn(24, 13, dtype=torch.cfloat),),
        #     (
        #         InputDim(0),
        #         Split(InputDim(1), (13, 2), 0),
        #         Split(InputDim(1), (13, 2), 1),
        #     ),
        # )

    @with_comms
    def test_complex_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
        inp = randn(24, 13, 2)
        intermediate = torch.view_as_complex(inp)
        out = torch.view_as_real(intermediate)

        # test dim_map correctness
        expected_view_as_complex_rule = (
            InputDim(0),
            Flatten((InputDim(1), InputDim(2))),
        )
        view_as_complex_rule = dim_maps[torch.view_as_complex](inp)
        self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
        expected_view_as_real_rule = (
            InputDim(0),
            Split(InputDim(1), (13, 2), 0),
            Split(InputDim(1), (13, 2), 1),
        )
        view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
        self.assertEqual(view_as_real_rule, expected_view_as_real_rule)

        # test sharded computation correctness
        # NOTE: For the input to torch.view_as_complex, sharding on the last
        # two dimensions is not supported.
        sharding_choices: list[Placement] = [Replicate(), Shard(0)]
        all_sharding_choices = itertools.product(
            *(self.device_mesh.ndim * [sharding_choices])
        )
        for inp_shard in all_sharding_choices:
            inp_dt = distribute_tensor(inp, self.device_mesh, inp_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                intermediate_dt = torch.view_as_complex(inp_dt)
                out_dt = torch.view_as_real(intermediate_dt)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )
            self.assertEqual(out, out_dt.full_tensor())

    @with_comms
    def test_dtensor_view_op_uneven(self):
        """
        Test two uneven cases for the view op:
        1) the sharded tensor dim is 1, so only the first rank has a non-empty shard;
        2) the sharded tensor dim is uneven, so some ranks have full shards,
           some have smaller non-empty shards, and some have empty shards.
        """
        dim0_sizes = [1, self.world_size + 1]
        for dim0_size in dim0_sizes:
            p = torch.randn(dim0_size, 2, 2, 2)
            mesh = init_device_mesh(self.device_type, (self.world_size,))
            dtensor = distribute_tensor(p, mesh, [Shard(0)])

            with CommDebugMode() as comm_mode:
                view = dtensor.view(dim0_size, 2, 4)
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)
                # when no communication happens, the data pointer should be the same.
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )

                view = dtensor.view(dim0_size, 4, 2)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dim0_size, 8)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dtensor.shape)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)


if __name__ == "__main__":
    run_tests()