mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Applies some more harmless pyupgrades. This one gets rid of deprecated aliases in unit tests, and also upgrades `yield`-inside-for loops into `yield from` generator delegation, which is more performant and propagates more information / exceptions from the original generator. This is the modern recommended way of forwarding generators. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94309 Approved by: https://github.com/albanD
464 lines
14 KiB
Python
464 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
# Owner(s): ["oncall: distributed"]
|
|
|
|
import itertools
|
|
from typing import cast, List
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch import rand, randn, Tensor
|
|
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
|
|
from torch.distributed._tensor.ops.view_ops import (
|
|
Broadcast,
|
|
Flatten,
|
|
InputDim,
|
|
ops,
|
|
Repeat,
|
|
Singleton,
|
|
Split,
|
|
view_groups,
|
|
)
|
|
from torch.distributed._tensor.placement_types import Placement
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|
DTensorTestBase,
|
|
redistribute_profiler,
|
|
with_comms,
|
|
)
|
|
from torch.utils._pytree import tree_flatten
|
|
|
|
|
|
class TestViewOps(DTensorTestBase):
|
|
def test_view_groups(self):
|
|
self.assertEqual(
|
|
view_groups([2, 3], [3, 2]),
|
|
(
|
|
Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
|
|
Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([3, 4, 5], [12, 5]),
|
|
(Flatten((InputDim(0), InputDim(1))), InputDim(2)),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([2, 3, 4, 5, 7], [12, 70]),
|
|
(
|
|
Split(
|
|
Flatten(
|
|
(
|
|
InputDim(0),
|
|
InputDim(1),
|
|
InputDim(2),
|
|
InputDim(3),
|
|
InputDim(4),
|
|
)
|
|
),
|
|
(12, 70),
|
|
0,
|
|
),
|
|
Split(
|
|
Flatten(
|
|
(
|
|
InputDim(0),
|
|
InputDim(1),
|
|
InputDim(2),
|
|
InputDim(3),
|
|
InputDim(4),
|
|
)
|
|
),
|
|
(12, 70),
|
|
1,
|
|
),
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]),
|
|
(
|
|
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0),
|
|
Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1),
|
|
Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0),
|
|
Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1),
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([3, 4, 8, 3], [12, 4, 2, 3]),
|
|
(
|
|
Flatten((InputDim(0), InputDim(1))),
|
|
Split(InputDim(2), (4, 2), 0),
|
|
Split(InputDim(2), (4, 2), 1),
|
|
InputDim(3),
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]),
|
|
(
|
|
Singleton(),
|
|
InputDim(0),
|
|
Split(InputDim(1), (2, 4, 3), 0),
|
|
Split(InputDim(1), (2, 4, 3), 1),
|
|
Singleton(),
|
|
Split(InputDim(1), (2, 4, 3), 2),
|
|
Singleton(),
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]),
|
|
(
|
|
Flatten((InputDim(2), InputDim(3))),
|
|
Singleton(),
|
|
Singleton(),
|
|
Singleton(),
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]),
|
|
(
|
|
Split(InputDim(2), (3, 4), 0),
|
|
Split(InputDim(2), (3, 4), 1),
|
|
Singleton(),
|
|
Flatten((InputDim(6), InputDim(7))),
|
|
),
|
|
)
|
|
self.assertEqual(
|
|
view_groups([2, 3, 4], [2, -1, 4]),
|
|
(InputDim(0), InputDim(1), InputDim(2)),
|
|
)
|
|
|
|
@property
|
|
def world_size(self) -> int:
|
|
return 6
|
|
|
|
    def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh) -> None:
        """Run ``op`` on DTensor inputs under every legal input sharding and
        compare against the plain-tensor result.

        For each mesh dimension the placement is drawn from ``Replicate()``
        plus ``Shard(i)`` for every input dim that has size > 1 and is not in
        the set of dims that must stay unsharded (derived from the op's
        dim-map rules). The op must execute without any redistribution.
        """
        spec = ops[op]
        rules = spec.dim_map(*args, **kwargs)
        # Reference output computed on plain (non-distributed) tensors.
        outputs = op(*args, **kwargs)
        flat_args, _ = tree_flatten(args)
        # NOTE(review): assumes the first flattened arg is the input tensor
        # whose shape drives the sharding choices — holds for all ops tested.
        in_shape = flat_args[0].shape

        # Collect input dims that may not be sharded: dims that get repeated,
        # and every non-leading dim of any flattened group (whether the
        # Flatten appears directly or inside a Split).
        no_shard_dims = set()
        for rule in rules:
            if isinstance(rule, Repeat):
                if isinstance(rule.input_dim, InputDim):
                    no_shard_dims.add(rule.input_dim.input_dim)
            elif isinstance(rule, Flatten):
                for dim in rule.input_dims[1:]:
                    if isinstance(dim, InputDim):
                        no_shard_dims.add(dim.input_dim)
            elif isinstance(rule, Split):
                if isinstance(rule.input_dim, Flatten):
                    for dim in rule.input_dim.input_dims[1:]:
                        if isinstance(dim, InputDim):
                            no_shard_dims.add(dim.input_dim)

        # torch.unbind consumes the target dim, so it must stay unsharded too.
        if op == torch.unbind:
            no_shard_dims.add(kwargs.get("dim", 0))

        sharding_choices = cast(List[Placement], [Replicate()]) + [
            Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
        ]

        # Cartesian product: one placement choice per mesh dimension.
        all_sharding_choices = itertools.product(
            *(device_mesh.ndim * [sharding_choices])
        )

        for in_shard in all_sharding_choices:
            in_dt = distribute_tensor(args[0], device_mesh, in_shard)

            with redistribute_profiler() as profiler:
                out_dt = op(in_dt, *args[1:], **kwargs)

            # View-like ops should be computed purely locally — any
            # redistribution here is a sharding-propagation bug.
            self.assertEqual(profiler.num_calls, 0, "Expected no redistribution.")

            # Gather the distributed result to a fully replicated local
            # tensor so it can be compared with the reference output.
            full_out = out_dt.redistribute(
                device_mesh, device_mesh.ndim * [Replicate()]
            ).to_local()

            if dist.get_rank() == 0:
                self.assertEqual(outputs, full_out)
|
|
|
|
def dimmap_test(self, op, args, expected_rule_output):
|
|
rules = ops[op].dim_map(*args)
|
|
self.assertEqual(rules, expected_rule_output)
|
|
self.call_dt_test(op, args, {}, self.device_mesh)
|
|
|
|
    @with_comms
    def test_view_ops(self):
        """Check dim-map rules and shard propagation for view-like ops.

        For each op, dimmap_test() first verifies the statically derived
        dim-transformation rules, then (via call_dt_test) runs the op on
        DTensors under all legal input shardings.
        """
        # 2-D mesh so both mesh dimensions are exercised (6 ranks -> 3x2).
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
        # atleast_1d: a scalar gains one singleton dim; others pass through.
        self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),))
        self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),))
        self.dimmap_test(torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1)))

        # atleast_2d: pad with leading singleton dims up to rank 2.
        self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton()))
        self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0)))
        self.dimmap_test(torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1)))
        self.dimmap_test(
            torch.atleast_2d,
            (randn(24, 36, 48),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        # atleast_3d: note the 1-D input is padded on BOTH sides.
        self.dimmap_test(
            torch.atleast_3d,
            (randn(()),),
            (Singleton(), Singleton(), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24),),
            (Singleton(), InputDim(0), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36),),
            (InputDim(0), InputDim(1), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42, 24),),
            (InputDim(0), InputDim(1), InputDim(2), InputDim(3)),
        )

        # broadcast_to with an incompatible target shape must be rejected.
        with self.assertRaises(AssertionError):
            ops[torch.broadcast_to].dim_map(randn(24, 36), (1, 2, 4))

        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (1, 24, 36)),
            (Singleton(), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (42, 24, 36)),
            (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (12, 24, 24, 36)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 24),
                InputDim(2),
            ),
        )
        # -1 in the target shape keeps the corresponding input dim.
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (-1, 36)),
            (InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (-1, 1, 36)),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.broadcast_to,
            (randn(36, 1, 24), (12, 36, 42, 24)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
            ),
        )

        # Tensor.expand accepts sizes either as varargs ...
        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), 36, 24, 42, -1, 24),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        # ... or as a single tuple; the rules must be identical.
        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            torch.flatten,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        # Flattening a 1-D or 0-D input is a no-op / singleton.
        self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),))

        # movedim with single dims and with dim tuples, incl. negative dims.
        self.dimmap_test(
            torch.movedim,
            (randn(12, 24, 48, 96), 1, 2),
            (InputDim(0), InputDim(2), InputDim(1), InputDim(3)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(6, 12, 24), 1, 0),
            (InputDim(1), InputDim(0), InputDim(2)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12, 6), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12), (1, 0), (0, 1)),
            (InputDim(1), InputDim(0)),
        )

        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        # Negative destination dims must resolve to the same rules as above.
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (-3, -2)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (2, 0, 1)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )
        # Same permutation expressed with negative dims.
        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (-1, -3, -2)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )

        # ravel behaves like flatten for these cases.
        self.dimmap_test(
            torch.ravel,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),))

        # repeat: extra leading repeat counts introduce new (broadcast) dims;
        # a count of 1 keeps the input dim, >1 yields a Repeat rule.
        self.dimmap_test(
            Tensor.repeat,
            (randn(24, 36), 1, 2, 1, 1, 2),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )

        self.dimmap_test(
            torch.reshape,
            (randn(6, 12, 24), (72, 24)),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        # tile follows the same rules as repeat when reps outnumber dims ...
        self.dimmap_test(
            torch.tile,
            (randn(24, 36), (1, 2, 1, 1, 2)),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )
        # ... and left-aligns reps when the input has more dims than reps.
        self.dimmap_test(
            torch.tile,
            (randn(42, 24, 36), (1, 3)),
            (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)),
        )

        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), 2, 0),
            (InputDim(2), InputDim(1), InputDim(0), InputDim(3)),
        )
        # Negative dim variant.
        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), -1, 0),
            (InputDim(3), InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.unsqueeze,
            (randn(42, 24, 36), 1),
            (InputDim(0), Singleton(), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(6, 12, 24), 72, 24),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        # view(-1): singleton dims are dropped before flattening.
        self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),))

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 24), -1),
            (Flatten((InputDim(2), InputDim(3))),),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 1, 24, 1), -1),
            (Flatten((InputDim(2), InputDim(4))),),
        )

        # General reshape where no input/output dim boundaries line up:
        # everything is flattened and re-split into each output dim.
        self.dimmap_test(
            Tensor.view,
            (randn(48, 35, 26), (24, 4, 35, 13)),
            (
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=0,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=1,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=2,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=3,
                ),
            ),
        )
|
|
|
|
|
|
# Standard PyTorch test entry point.
if __name__ == "__main__":
    run_tests()
|