# Owner(s): ["oncall: distributed"]
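"""Tests for CommDebugMode, which counts the collective ops issued under it:
functional collectives, collectives triggered by DTensor operations, and
in-place c10d collectives."""
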
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


# CommDebugMode keys its counts by dispatcher op: functional collectives live
# under torch.ops.c10d_functional, in-place c10d collectives under torch.ops.c10d.
c10d_functional = torch.ops.c10d_functional
c10d_ops = torch.ops.c10d


class TestCommMode(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def setUp(self):
        super().setUp()
        self.world_size = 2
        # The "fake" backend provides a process group that satisfies the c10d
        # API without doing any real communication, so this suite runs in a
        # single process.
        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=1, world_size=self.world_size, store=store
        )
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.world_pg = dist.distributed_c10d._get_default_group()

    def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
        # Helper: assert both the total number of collectives CommDebugMode
        # recorded and the count for one specific op.
        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), expected_total_value)
        self.assertEqual(comm_counts[key], expected_value)

    def test_comm_mode(self):
        world_pg = self.world_pg

        class WrapperModel(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.model = MLPModule(device=device)

            def forward(self, x):
                x = funcol.all_gather_tensor(x, 0, world_pg)
                x = funcol.reduce_scatter_tensor(x, "sum", 0, world_pg)
                out = self.model(x)
                return funcol.all_reduce(out, "sum", world_pg)

        model = WrapperModel(self.device_type)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(torch.randn(20, 10, device=self.device_type))

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 3)
        self.assertEqual(comm_counts[c10d_functional.all_reduce], 1)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)

    def test_comm_mode_coalesced(self):
        world_pg = self.world_pg

        class WrapperModelCoalesced(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.model = MLPModule(device=device)

            def forward(self, x):
                x = funcol.all_gather_tensor(x, 0, world_pg)
                x = funcol.reduce_scatter_tensor(x, "sum", 0, world_pg)
                out = self.model(x)
                return funcol.all_reduce_coalesced([out], "sum", world_pg)

        model = WrapperModelCoalesced(self.device_type)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(torch.randn(20, 10, device=self.device_type))

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 3)
        self.assertEqual(comm_counts[c10d_functional.all_reduce_coalesced], 1)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)

    def test_comm_mode_with_dtensor(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        def f(x, y):
            return torch.mm(x, y)

        comm_mode = CommDebugMode()
        x = torch.randn(4, 8, requires_grad=True)
        y = torch.randn(4, 32, requires_grad=True)
        x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)

        with comm_mode:
            f(x_dtensor, y_dtensor)
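
        # Multiplying two row-sharded (Shard(0)) DTensors requires gathering
        # one operand, so exactly one all-gather and no reduction collectives
        # are expected.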
        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(comm_counts[c10d_functional.all_reduce], 0)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)

    @requires_nccl()
    def test_comm_mode_with_c10d(self):
        if not torch.cuda.is_available():
            return

        inp = torch.rand(2, 8, 16).cuda()
        all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)

        comm_mode = CommDebugMode()

        # tests c10d all_reduce tracing
        with comm_mode:
            dist.all_reduce(inp)

        self.checksAssert(comm_mode, c10d_ops.allreduce_, 1, 1)

        # tests c10d all_gather_into_tensor tracing
        with comm_mode:
            dist.all_gather_into_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops._allgather_base_, 1, 1)

        # tests c10d reduce_scatter tracing
        with comm_mode:
            dist.reduce_scatter_tensor(inp, all_gather_out)

        self.checksAssert(comm_mode, c10d_ops._reduce_scatter_base_, 1, 1)

        # tests c10d broadcast tracing
        with comm_mode:
            dist.broadcast(inp, 0)

        self.checksAssert(comm_mode, c10d_ops.broadcast_, 1, 1)

        # tests c10d gather tracing
        with comm_mode:
            dist.gather(inp, None, 0)

        self.checksAssert(comm_mode, c10d_ops.gather_, 1, 1)

        # tests c10d reduce tracing
        with comm_mode:
            dist.reduce(inp, 0)

        self.checksAssert(comm_mode, c10d_ops.reduce_, 1, 1)

        # tests c10d scatter tracing
        with comm_mode:
            dist.scatter(inp, None, 0)

        self.checksAssert(comm_mode, c10d_ops.scatter_, 1, 1)

        # tests c10d all_gather tracing
        output_list = []

        with comm_mode:
            dist.all_gather(output_list, inp, None)

        self.checksAssert(comm_mode, c10d_ops.allgather_, 1, 1)

        # tests c10d allgather_coalesced_ tracing
        output_list = []

        with comm_mode:
            dist.all_gather_coalesced(output_list, [inp], None)

        self.checksAssert(comm_mode, c10d_ops.allgather_coalesced_, 1, 1)

        # tests c10d allgather_into_tensor_coalesced_ tracing
        with comm_mode, dist._coalescing_manager():
            dist.all_gather_into_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops.allgather_into_tensor_coalesced_, 1, 1)

        # tests c10d allreduce_coalesced tracing
        with comm_mode:
            dist.all_reduce_coalesced(inp)

        self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1)

        # tests c10d reduce_scatter_ tracing
        with comm_mode:
            dist.reduce_scatter(all_gather_out, [inp])

        self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1)

        # tests c10d reduce_scatter_tensor_coalesced tracing
        with comm_mode, dist._coalescing_manager():
            dist.reduce_scatter_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)

        # tests c10d alltoall_ tracing
        with comm_mode:
            dist.all_to_all([inp], [inp])

        self.checksAssert(comm_mode, c10d_ops.alltoall_, 1, 1)

        # tests c10d alltoall_base_ tracing
        with comm_mode:
            dist.all_to_all_single(inp, inp)

        self.checksAssert(comm_mode, c10d_ops.alltoall_base_, 1, 1)
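
# A minimal usage sketch of CommDebugMode outside this test harness, assuming
# an initialized process group and some `model`/`inp` already defined:
#
#     comm_mode = CommDebugMode()
#     with comm_mode:
#         out = model(inp)
#     print(comm_mode.get_total_counts())
#     print(comm_mode.get_comm_counts())
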
if __name__ == "__main__":
    run_tests()