Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68786

To enable autograd for the sharded linear, we need to make some changes to the current nn functional collectives API (the c10d API with autograd enabled). This change therefore does the following:

1. Add a new `reduce_scatter` API, which is needed for row-wise sharding.
2. Modify the `all_to_all` API so that it is consistent with the one in distributed_c10d.py.
3. Fix the C++ binding of `reduce_scatter`, which was missing an input parameter, and add more unit tests to cover these cases.
4. Sync the NN tests from gloo to nccl.

ghstack-source-id: 144860208

Test Plan: CI + unit tests

Reviewed By: pritamdamania87

Differential Revision: D32569674

fbshipit-source-id: 9bd613f91bbf7a39eede0af32a5a5db0f2ade43b
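As a rough illustration of the autograd-enabled `reduce_scatter` described above, here is a minimal, hypothetical sketch (not part of the file below). The spawn wiring, master address/port, tensor shapes, and two-GPU assumption are illustrative only; the `torch.distributed.nn.reduce_scatter` call mirrors the one exercised in `test_reduce_scatter` in the test file.

```python
# Minimal sketch (illustrative, not from this PR): using the autograd-enabled
# reduce_scatter. Assumes 2 CUDA GPUs and NCCL; launch details are placeholders.
import os

import torch
import torch.distributed as dist
import torch.distributed.nn
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Hypothetical single-node rendezvous settings.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{rank}")

    # One input chunk per destination rank; gradients flow back through the collective.
    inputs = [
        torch.full((4, 4), float(rank + dst), device=device, requires_grad=True)
        for dst in range(world_size)
    ]
    out = torch.empty(4, 4, device=device)
    # Sum-reduce across ranks; each rank receives the reduced chunk addressed to it.
    out = torch.distributed.nn.reduce_scatter(out, inputs)

    out.sum().backward()  # backward all-gathers the output gradients
    print(rank, inputs[0].grad is not None)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

The forward pass sum-reduces the inputs and scatters chunk `rank` to each rank; the backward of the collective all-gathers the output gradients, which is what `test_reduce_scatter` below checks against `cos(y)`.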
# Owner(s): ["oncall: distributed"]

import sys

import test_c10d_spawn
import torch
import torch.distributed as c10d
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    sandcastle_skip_if,
    TEST_WITH_DEV_DBG_ASAN,
)

NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):

    class ProcessGroupShareTensorTest(
        test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
    ):
        @classmethod
        def _init_pg_nccl(cls, rank, filename, world_size):
            store = c10d.FileStore(filename, world_size)
            return c10d.ProcessGroupNCCL(store, rank, world_size)

        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
        @sandcastle_skip_if(NO_NCCL, "NCCL needed")
        def test_shared_broadcast_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_broadcast_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1,
            )

        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
        @sandcastle_skip_if(NO_NCCL, "NCCL needed")
        def test_shared_allreduce_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allreduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1,
            )

        @classmethod
        def _test_reduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
        ):
            pg = init_pg(rank, filename, world_size)
            x = shared_tensors[rank]
            pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
            if rank == 0:
                c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
            else:
                c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
            p2c.get()

        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
        @sandcastle_skip_if(NO_NCCL, "NCCL needed")
        def test_shared_reduce_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_reduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1,
            )

        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
        @sandcastle_skip_if(NO_NCCL, "NCCL needed")
        def test_shared_allgather_nccl(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                self.world_size,
            )

# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:

    class TestDistributedNNFunctionsNccl(TestDistributedNNFunctions):
        # Test Common Ops First.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_broadcast(self):
            self._test_broadcast("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_reduce(self):
            self._test_reduce("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_allreduce(self):
            self._test_allreduce("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_gather(self):
            self._test_all_gather("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all(self):
            self._test_all_to_all("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all_single(self):
            self._test_all_to_all_single("nccl")

        # Test Ops only supported in NCCL.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_reduce_scatter(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into
            # torch.distributed and need the default process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
            device = torch.device(f"cuda:{self.rank}")
            x0 = torch.ones(5, 5, device=device) + self.rank
            x1 = torch.ones(5, 5, device=device) + self.rank + 1
            x0.requires_grad = True
            x1.requires_grad = True
            y = torch.empty_like(x0)
            # Rank r contributes (1 + r) to chunk 0 and (2 + r) to chunk 1, so after the
            # sum-reduce, rank i's chunk is filled with
            # (1 + world_size) * world_size / 2 + world_size * i.
            expected = (1 + self.world_size) * self.world_size / 2 + self.world_size * self.rank
            y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
            self.assertEqual(y, torch.ones(5, 5, device=device) * expected)
            z = y.sin().sum()
            z.backward()
            # The backward of reduce_scatter all-gathers the output gradients, so
            # x_k.grad on every rank equals cos(y) as computed on rank k.
            expected_0 = (1 + self.world_size) * self.world_size / 2
            expected_1 = expected_0 + self.world_size
            x_s_0 = (expected_0 * torch.ones(5, 5, device=device)).cos()
            x_s_1 = (expected_1 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x0.grad, x_s_0)
            self.assertEqual(x1.grad, x_s_1)

if __name__ == "__main__":
    run_tests()