# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools
from typing import cast

import torch
import torch.distributed as dist
from torch import rand, randn, Tensor
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor._ops._view_ops import (
    Broadcast,
    dim_maps,
    Flatten,
    InputDim,
    Repeat,
    Singleton,
    Split,
    view_groups,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.utils import _pytree as pytree
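

# These tests exercise the dim-map machinery behind DTensor view ops:
# `view_groups(from_shape, to_shape)` decomposes a reshape into one rule per
# output dim, and `dim_maps[op]` produces the analogous rules for a given
# view-like op (view, reshape, permute, expand, ...).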
class TestViewOps(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 6
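
    # Rule vocabulary used below: `InputDim(i)` passes input dim i through,
    # `Flatten(dims)` merges input dims, `Split(dim, group_shape, split_id)`
    # cuts a (possibly flattened) group into pieces, `Singleton()` is a size-1
    # output dim with no input counterpart, and `Broadcast`/`Repeat` expand or
    # tile a dim.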
    def test_view_groups(self):
        self.assertEqual(
            view_groups([2, 3], [3, 2]),
            (
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 5], [12, 5]),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [12, 70]),
            (
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    0,
                ),
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    1,
                ),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]),
            (
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0),
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 8, 3], [12, 4, 2, 3]),
            (
                Flatten((InputDim(0), InputDim(1))),
                Split(InputDim(2), (4, 2), 0),
                Split(InputDim(2), (4, 2), 1),
                InputDim(3),
            ),
        )
        self.assertEqual(
            view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]),
            (
                Singleton(),
                InputDim(0),
                Split(InputDim(1), (2, 4, 3), 0),
                Split(InputDim(1), (2, 4, 3), 1),
                Singleton(),
                Split(InputDim(1), (2, 4, 3), 2),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]),
            (
                Flatten((InputDim(2), InputDim(3))),
                InputDim(4),
                InputDim(5),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]),
            (
                Split(InputDim(2), (3, 4), 0),
                Split(InputDim(2), (3, 4), 1),
                InputDim(3),
                Flatten((InputDim(6), InputDim(7))),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4], [2, -1, 4]),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
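
    # Check `op` under every sharding its dim-map rules allow: the DTensor
    # result must match eager execution and must not trigger any collective
    # (verified via CommDebugMode).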
    def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
        dim_map = dim_maps[op]
        rules = dim_map(*args, **kwargs)
        outputs = op(*args, **kwargs)
        flat_args = pytree.arg_tree_leaves(*args)
        in_shape = flat_args[0].shape

        # Collect input dims that must not be sharded: any dim that is
        # repeated, or that sits in a non-leading position of a flattened
        # group, cannot stay sharded without redistribution.
        no_shard_dims = set()
        for rule in rules:
            if isinstance(rule, Repeat):
                if isinstance(rule.input_dim, InputDim):
                    no_shard_dims.add(rule.input_dim.input_dim)
            elif isinstance(rule, Flatten):
                for dim in rule.input_dims[1:]:
                    if isinstance(dim, InputDim):
                        no_shard_dims.add(dim.input_dim)
            elif isinstance(rule, Split):
                if isinstance(rule.input_dim, Flatten):
                    for dim in rule.input_dim.input_dims[1:]:
                        if isinstance(dim, InputDim):
                            no_shard_dims.add(dim.input_dim)

        if op == torch.unbind:
            no_shard_dims.add(kwargs.get("dim", 0))

        sharding_choices = cast(list[Placement], [Replicate()]) + [
            Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
        ]

        all_sharding_choices = itertools.product(
            *(device_mesh.ndim * [sharding_choices])
        )

        for in_shard in all_sharding_choices:
            in_dt = distribute_tensor(args[0], device_mesh, in_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                out_dt = op(in_dt, *args[1:], **kwargs)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )

            full_out = out_dt.full_tensor()

            if dist.get_rank() == 0:
                self.assertEqual(outputs, full_out)

    def dimmap_test(self, op, args, expected_rule_output):
        rules = dim_maps[op](*args)
        self.assertEqual(rules, expected_rule_output)
        self.call_dt_test(op, args, {}, self.device_mesh)
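
    # With world_size == 6, torch.arange(6).view(-1, 2) yields a 3 x 2 device
    # mesh, so every sharding choice below is tested across two mesh dims.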
    @with_comms
    def test_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
        self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),))
        self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),))
        self.dimmap_test(torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1)))

        self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton()))
        self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0)))
        self.dimmap_test(torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1)))
        self.dimmap_test(
            torch.atleast_2d,
            (randn(24, 36, 48),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.atleast_3d,
            (randn(()),),
            (Singleton(), Singleton(), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24),),
            (Singleton(), InputDim(0), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36),),
            (InputDim(0), InputDim(1), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42, 24),),
            (InputDim(0), InputDim(1), InputDim(2), InputDim(3)),
        )

        with self.assertRaises(AssertionError):
            dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4))

        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (1, 24, 36)),
            (Singleton(), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (42, 24, 36)),
            (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (12, 24, 24, 36)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 24),
                InputDim(2),
            ),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (-1, 36)),
            (InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (-1, 1, 36)),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.broadcast_to,
            (randn(36, 1, 24), (12, 36, 42, 24)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
            ),
        )

        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), 36, 24, 42, -1, 24),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            torch.flatten,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),))

        self.dimmap_test(
            torch.movedim,
            (randn(12, 24, 48, 96), 1, 2),
            (InputDim(0), InputDim(2), InputDim(1), InputDim(3)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(6, 12, 24), 1, 0),
            (InputDim(1), InputDim(0), InputDim(2)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12, 6), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12), (1, 0), (0, 1)),
            (InputDim(1), InputDim(0)),
        )

        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (-3, -2)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (2, 0, 1)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (-1, -3, -2)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )

        self.dimmap_test(
            torch.ravel,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),))

        self.dimmap_test(
            Tensor.repeat,
            (randn(24, 36), 1, 2, 1, 1, 2),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )

        self.dimmap_test(
            torch.reshape,
            (randn(6, 12, 24), (72, 24)),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(
            torch.tile,
            (randn(24, 36), (1, 2, 1, 1, 2)),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )
        self.dimmap_test(
            torch.tile,
            (randn(42, 24, 36), (1, 3)),
            (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)),
        )

        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), 2, 0),
            (InputDim(2), InputDim(1), InputDim(0), InputDim(3)),
        )
        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), -1, 0),
            (InputDim(3), InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.unsqueeze,
            (randn(42, 24, 36), 1),
            (InputDim(0), Singleton(), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(6, 12, 24), 72, 24),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),))

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 24), -1),
            (Flatten((InputDim(2), InputDim(3))),),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 1, 24, 1), -1),
            (Flatten((InputDim(2), InputDim(input_dim=3), InputDim(4))),),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(48, 35, 26), (24, 4, 35, 13)),
            (
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=0,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=1,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=2,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=3,
                ),
            ),
        )

        # TODO: Currently functional collectives on complex numbers are not
        # fully supported, so view_as_complex and view_as_real are covered by
        # a standalone test below. Once complex numbers are supported, the
        # following can be added to the dim-map tests above:
        #
        # self.dimmap_test(
        #     torch.view_as_complex,
        #     (randn(24, 13, 2),),
        #     (
        #         InputDim(0),
        #         Flatten((InputDim(1), InputDim(2))),
        #     ),
        # )
        # self.dimmap_test(
        #     torch.view_as_real,
        #     (torch.randn(24, 13, dtype=torch.cfloat),),
        #     (
        #         InputDim(0),
        #         Split(InputDim(1), (13, 2), 0),
        #         Split(InputDim(1), (13, 2), 1),
        #     ),
        # )
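
    # `torch.view_as_complex` pairs up the trailing size-2 (real, imag) dim,
    # which the dim map expresses as flattening the last two input dims;
    # `torch.view_as_real` is the inverse Split.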
    @with_comms
    def test_complex_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
        inp = randn(24, 13, 2)
        intermediate = torch.view_as_complex(inp)
        out = torch.view_as_real(intermediate)

        # test dim_map correctness
        expected_view_as_complex_rule = (
            InputDim(0),
            Flatten((InputDim(1), InputDim(2))),
        )
        view_as_complex_rule = dim_maps[torch.view_as_complex](inp)
        self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
        expected_view_as_real_rule = (
            InputDim(0),
            Split(InputDim(1), (13, 2), 0),
            Split(InputDim(1), (13, 2), 1),
        )
        view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
        self.assertEqual(view_as_real_rule, expected_view_as_real_rule)

        # test sharded computation correctness
        # NOTE: For the input to torch.view_as_complex, sharding
        # on the last two dimensions is not supported.
        sharding_choices: list[Placement] = [Replicate(), Shard(0)]
        all_sharding_choices = itertools.product(
            *(self.device_mesh.ndim * [sharding_choices])
        )

        for inp_shard in all_sharding_choices:
            inp_dt = distribute_tensor(inp, self.device_mesh, inp_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                intermediate_dt = torch.view_as_complex(inp_dt)
                out_dt = torch.view_as_real(intermediate_dt)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )
            self.assertEqual(out, out_dt.full_tensor())

    @with_comms
    def test_dtensor_view_op_uneven(self):
        """
        Test two uneven cases for view op:
        1) the sharded tensor dim is 1, so that only the first rank has a non-empty shard.
        2) the sharded tensor dim is uneven, such that some ranks have full shards,
           smaller non-empty shards, and empty shards.
        """
        dim0_sizes = [1, self.world_size + 1]
        for dim0_size in dim0_sizes:
            p = torch.randn(dim0_size, 2, 2, 2)
            mesh = init_device_mesh(self.device_type, (self.world_size,))
            dtensor = distribute_tensor(p, mesh, [Shard(0)])

            with CommDebugMode() as comm_mode:
                view = dtensor.view(dim0_size, 2, 4)
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)
                # When no communication happens, the data pointer should be the same.
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )

                view = dtensor.view(dim0_size, 4, 2)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dim0_size, 8)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dtensor.shape)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)


if __name__ == "__main__":
    run_tests()