From 7c2489bdae5a96dc122c3bb7b42c18528bcfdc86 Mon Sep 17 00:00:00 2001 From: Junjie Wang Date: Mon, 6 Dec 2021 13:37:10 -0800 Subject: [PATCH] [PyTorch][Distributed] Enable Reduce Scatter and modify all_to_all for sharded linear with more test cases. (#68786) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68786 To enable the auto grad for the sharded linear, we find we need to make some changes to the current nn function api (c10d api with auto grad enabled). So we made the following several changes: 1. Add a new api `reduce_scatter` since we need it in the rowwise sharding. 2. Modify the `all_to_all` api to make sure it consistent with the ones in distributed_c10d.py. 3. Found the cpp input params of `reduce_scatter` is missing input param, added more unit test to cover these cases. 4. Sync the NN test from gloo to nccl. ghstack-source-id: 144860208 Test Plan: CI + Unit Test Reviewed By: pritamdamania87 Differential Revision: D32569674 fbshipit-source-id: 9bd613f91bbf7a39eede0af32a5a5db0f2ade43b --- test/distributed/test_c10d_nccl.py | 29 ++++ test/distributed/test_c10d_spawn.py | 158 +++++++++++++++++++++- test/distributed/test_c10d_spawn_gloo.py | 160 +++++------------------ test/distributed/test_c10d_spawn_nccl.py | 104 +++++++++++++-- torch/csrc/distributed/c10d/init.cpp | 9 +- torch/distributed/nn/functional.py | 130 ++++++++++++++++-- 6 files changed, 440 insertions(+), 150 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 8199ea788f0..e62e6f41e6f 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -611,6 +611,35 @@ class ProcessGroupNCCLTest(MultiProcessTestCase): # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 self.assertEqualIgnoreType(expected, output[i]) + # Test the input params overridden scenarios, aka, when the input is + # a list and output is just one tensor. + # Sum + output_tensor = torch.empty_like(input_per_gpu[0][0]).cuda(self.rank) + input_list = [tensor[0].cuda(self.rank) for tensor in input_per_gpu] + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait() + expected = torch.tensor( + float((1 + self.world_size) * self.world_size / 2) + self.world_size * self.rank + ) + self.assertEqualIgnoreType(expected, output_tensor) + + # Min + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait() + expected = torch.tensor(self.rank + 1) + self.assertEqualIgnoreType(expected, output_tensor) + + # Max + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait() + expected = torch.tensor(self.rank + self.world_size) + self.assertEqualIgnoreType(expected, output_tensor) + + # Product + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait() + prod_val = self.rank + 1 + for k in range(1, self.world_size): + prod_val = prod_val * (self.rank + 1 + k) + expected = torch.tensor(prod_val) + self.assertEqualIgnoreType(expected, output_tensor) + @requires_nccl() @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs") def test_reduce_scatter_base_ops(self): diff --git a/test/distributed/test_c10d_spawn.py b/test/distributed/test_c10d_spawn.py index 7cc11d3b2de..0e87bdc1729 100644 --- a/test/distributed/test_c10d_spawn.py +++ b/test/distributed/test_c10d_spawn.py @@ -1,13 +1,16 @@ # Owner(s): ["oncall: distributed"] +import os import sys import tempfile import torch import torch.distributed as c10d import torch.multiprocessing as mp -from torch.testing._internal.common_utils import NO_MULTIPROCESSING_SPAWN -from torch.testing._internal.common_utils import load_tests +from torch.testing._internal.common_distributed import \ + MultiProcessTestCase +from torch.testing._internal.common_utils import load_tests,\ + NO_MULTIPROCESSING_SPAWN # Torch distributed.nn is not available in windows # check #42095, it errors on import. @@ -96,3 +99,154 @@ class AbstractProcessGroupShareTensorTest(object): c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu"))) p2c.get() + + +class TestDistributedNNFunctions(MultiProcessTestCase): + def setUp(self): + super(TestDistributedNNFunctions, self).setUp() + self._spawn_processes() + + def tearDown(self): + super(TestDistributedNNFunctions, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def op_timeout_sec(self): + return 1 + + @property + def world_size(self): + return 2 + + def _test_broadcast(self, backend): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group( + store=store, rank=self.rank, world_size=self.world_size, backend=backend + ) + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + y = torch.distributed.nn.broadcast(x, 1) + self.assertEqual(y, 1 + torch.ones(5, 5)) + z = y.sin().sum() + z.backward() + # We can't check the gradient of communications numerically so we have to do some calculations + if self.rank == 1: + self.assertEqual(x.grad, 2 * torch.cos(x)) + elif self.rank == 0: + self.assertEqual(x.grad, torch.zeros(5, 5, device=device)) + + def _test_reduce(self, backend): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group( + store=store, rank=self.rank, world_size=self.world_size, backend=backend + ) + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM) + + if self.rank == 1: + self.assertEqual(y, 3 * torch.ones(5, 5, device=device)) + + z = y.sin().sum() + z.backward() + # Gradients are broadcasted to both ranks + x_g = (3 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x.grad, x_g) + + def _test_allreduce(self, backend): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group( + store=store, rank=self.rank, world_size=self.world_size, backend=backend + ) + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM) + + self.assertEqual(y, 3 * torch.ones(5, 5, device=device)) + + z = y.sin().sum() + z.backward() + x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x.grad, x_g) + + def _test_all_gather(self, backend): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group( + store=store, rank=self.rank, world_size=self.world_size, backend=backend + ) + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + tensors = torch.distributed.nn.all_gather(x) + for i, t in enumerate(tensors): + self.assertEqual(t, torch.ones(5, 5, device=device) + i) + y = torch.sum(torch.stack(tensors), axis=0) + z = y.sin().sum() + z.backward() + + x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x.grad, x_s) + + def _test_all_to_all(self, backend): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group( + store=store, rank=self.rank, world_size=self.world_size, backend=backend + ) + device = torch.device(f"cuda:{self.rank}") + x0 = torch.ones(5, 5, device=device) + 2 * self.rank + x1 = torch.ones(5, 5, device=device) + 2 * self.rank + x0.requires_grad = True + x1.requires_grad = True + y0 = torch.empty_like(x0) + y1 = torch.empty_like(x1) + tensors = torch.distributed.nn.all_to_all([y0, y1], [x0, x1]) + for i, t in enumerate(tensors): + self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i) + y = torch.sum(torch.stack(tensors), axis=0) + z = y.sin().sum() + z.backward() + x_s = (4 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x0.grad, x_s) + self.assertEqual(x1.grad, x_s) + + def _test_all_to_all_single(self, backend): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group( + store=store, rank=self.rank, world_size=self.world_size, backend=backend + ) + device = torch.device(f"cuda:{self.rank}") + row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2 + x = torch.ones(int(row), 5, device=device) * (self.rank + 1) + x.requires_grad = True + y = torch.empty_like(x) + split_sizes = [(i + 1) * (self.rank + 1) for i in range(self.world_size)] + y = torch.distributed.nn.all_to_all_single( + y, x, output_split_sizes=split_sizes, input_split_sizes=split_sizes + ) + expected = [] + for idx, tensor in enumerate(torch.split(x, split_sizes)): + expected.append(torch.full_like(tensor, (idx + 1))) + expected = torch.cat(expected) + self.assertEqual(y, expected) + z = y.sin().sum() + z.backward() + x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos() + self.assertEqual(x.grad, x_s) diff --git a/test/distributed/test_c10d_spawn_gloo.py b/test/distributed/test_c10d_spawn_gloo.py index 02040152342..e13f17350ff 100644 --- a/test/distributed/test_c10d_spawn_gloo.py +++ b/test/distributed/test_c10d_spawn_gloo.py @@ -9,10 +9,10 @@ import test_c10d_spawn import torch import torch.distributed as c10d import torch.nn as nn -from test_c10d_spawn import _torch_dist_nn_available +from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_distributed import requires_gloo, \ - create_device, MultiProcessTestCase, skip_if_lt_x_gpu + create_device, skip_if_lt_x_gpu from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_DEV_DBG_ASAN # Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 @@ -176,47 +176,45 @@ class DistributedDataParallelSingleProcessTest(TestCase): # Skip dev-asan as torch + multiprocessing spawn have known issues if not TEST_WITH_DEV_DBG_ASAN: - class TestDistributedNNFunctions(MultiProcessTestCase): - def setUp(self): - super(TestDistributedNNFunctions, self).setUp() - self._spawn_processes() - - def tearDown(self): - super(TestDistributedNNFunctions, self).tearDown() - try: - os.remove(self.file_name) - except OSError: - pass - - @property - def op_timeout_sec(self): - return 1 - - @property - def world_size(self): - return 2 - + class TestDistributedNNFunctionsGloo(TestDistributedNNFunctions): + # Test Common Ops First. @requires_gloo() @skip_if_lt_x_gpu(2) @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") def test_broadcast(self): - store = c10d.FileStore(self.file_name, self.world_size) - # This is required because these functions calls directly to the .dist and needs - # the world to be initialized - c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') - device = torch.device(f"cuda:{self.rank}") - x = torch.ones(5, 5, device=device) + self.rank - x.requires_grad = True - y = torch.distributed.nn.broadcast(x, 1) - self.assertEqual(y, 1 + torch.ones(5, 5)) - z = y.sin().sum() - z.backward() - # We can't check the gradient of communications numerically so we have to do some calculations - if self.rank == 1: - self.assertEqual(x.grad, 2 * torch.cos(x)) - elif self.rank == 0: - self.assertEqual(x.grad, torch.zeros(5, 5, device=device)) + self._test_broadcast("gloo") + @requires_gloo() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_reduce(self): + self._test_reduce("gloo") + + @requires_gloo() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_allreduce(self): + self._test_allreduce("gloo") + + @requires_gloo() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_gather(self): + self._test_all_gather("gloo") + + @requires_gloo() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all(self): + self._test_all_to_all("gloo") + + @requires_gloo() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all_single(self): + self._test_all_to_all_single("gloo") + + # Test Ops only supported in GLOO. @requires_gloo() @skip_if_lt_x_gpu(2) @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") @@ -275,92 +273,6 @@ if not TEST_WITH_DEV_DBG_ASAN: if self.rank == 0: self.assertEqual(x0.grad, torch.zeros(5, 5, device=device)) - @requires_gloo() - @skip_if_lt_x_gpu(2) - @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") - def test_reduce(self): - store = c10d.FileStore(self.file_name, self.world_size) - # This is required because these functions calls directly to the .dist and needs - # the world to be initialized - c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') - device = torch.device(f"cuda:{self.rank}") - x = torch.ones(5, 5, device=device) + self.rank - x.requires_grad = True - y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM) - - if self.rank == 1: - self.assertEqual(y, 3 * torch.ones(5, 5, device=device)) - - z = y.sin().sum() - z.backward() - # Gradients are broadcasted to both ranks - x_g = (3 * torch.ones(5, 5, device=device)).cos() - self.assertEqual(x.grad, x_g) - - @requires_gloo() - @skip_if_lt_x_gpu(2) - @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") - def test_allreduce(self): - store = c10d.FileStore(self.file_name, self.world_size) - # This is required because these functions calls directly to the .dist and needs - # the world to be initialized - c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') - device = torch.device(f"cuda:{self.rank}") - x = torch.ones(5, 5, device=device) + self.rank - x.requires_grad = True - y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM) - - self.assertEqual(y, 3 * torch.ones(5, 5, device=device)) - - z = y.sin().sum() - z.backward() - x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos() - self.assertEqual(x.grad, x_g) - - @requires_gloo() - @skip_if_lt_x_gpu(2) - @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") - def test_all_gather(self): - store = c10d.FileStore(self.file_name, self.world_size) - # This is required because these functions calls directly to the .dist and needs - # the world to be initialized - c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') - device = torch.device(f"cuda:{self.rank}") - x = torch.ones(5, 5, device=device) + self.rank - x.requires_grad = True - tensors = torch.distributed.nn.all_gather(x) - for i, t in enumerate(tensors): - self.assertEqual(t, torch.ones(5, 5, device=device) + i) - y = torch.sum(torch.stack(tensors), axis=0) - z = y.sin().sum() - z.backward() - - x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos() - self.assertEqual(x.grad, x_s) - - @requires_gloo() - @skip_if_lt_x_gpu(2) - @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") - def test_all_to_all(self): - store = c10d.FileStore(self.file_name, self.world_size) - # This is required because these functions calls directly to the .dist and needs - # the world to be initialized - c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') - device = torch.device(f"cuda:{self.rank}") - x0 = torch.ones(5, 5, device=device) + 2 * self.rank - x1 = torch.ones(5, 5, device=device) + 2 * self.rank - x0.requires_grad = True - x1.requires_grad = True - tensors = torch.distributed.nn.all_to_all([x0, x1]) - for i, t in enumerate(tensors): - self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i) - y = torch.sum(torch.stack(tensors), axis=0) - z = y.sin().sum() - z.backward() - x_s = (4 * torch.ones(5, 5, device=device)).cos() - self.assertEqual(x0.grad, x_s) - self.assertEqual(x1.grad, x_s) - if __name__ == '__main__': run_tests() diff --git a/test/distributed/test_c10d_spawn_nccl.py b/test/distributed/test_c10d_spawn_nccl.py index 362de8c700b..427fae41898 100644 --- a/test/distributed/test_c10d_spawn_nccl.py +++ b/test/distributed/test_c10d_spawn_nccl.py @@ -4,15 +4,27 @@ import sys import test_c10d_spawn import torch import torch.distributed as c10d +from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if +from torch.testing._internal.common_distributed import ( + requires_nccl, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, + sandcastle_skip_if, + TEST_WITH_DEV_DBG_ASAN, +) NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL") # Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 if sys.version_info < (3, 9): - class ProcessGroupShareTensorTest(test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase): + class ProcessGroupShareTensorTest( + test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase + ): @classmethod def _init_pg_nccl(cls, rank, filename, world_size): store = c10d.FileStore(filename, world_size) @@ -25,7 +37,8 @@ if sys.version_info < (3, 9): ProcessGroupShareTensorTest._test_broadcast_process, [torch.ones(2, 2).to(i) * i for i in range(self.world_size)], ProcessGroupShareTensorTest._init_pg_nccl, - 1) + 1, + ) @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") @sandcastle_skip_if(NO_NCCL, "NCCL needed") @@ -34,11 +47,13 @@ if sys.version_info < (3, 9): ProcessGroupShareTensorTest._test_allreduce_process, [torch.ones(2, 2).to(i) for i in range(self.world_size)], ProcessGroupShareTensorTest._init_pg_nccl, - 1) + 1, + ) @classmethod def _test_reduce_process( - cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c): + cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c + ): pg = init_pg(rank, filename, world_size) x = shared_tensors[rank] pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait() @@ -55,7 +70,8 @@ if sys.version_info < (3, 9): ProcessGroupShareTensorTest._test_reduce_process, [torch.ones(2, 2).to(i) for i in range(self.world_size)], ProcessGroupShareTensorTest._init_pg_nccl, - 1) + 1, + ) @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") @sandcastle_skip_if(NO_NCCL, "NCCL needed") @@ -64,8 +80,80 @@ if sys.version_info < (3, 9): ProcessGroupShareTensorTest._test_allgather_process, [torch.ones(2, 2).to(i) * i for i in range(self.world_size)], ProcessGroupShareTensorTest._init_pg_nccl, - self.world_size) + self.world_size, + ) -if __name__ == '__main__': +# Skip dev-asan as torch + multiprocessing spawn have known issues +if not TEST_WITH_DEV_DBG_ASAN: + + class TestDistributedNNFunctionsNccl(TestDistributedNNFunctions): + # Test Common Ops First. + @requires_nccl() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if( + not _torch_dist_nn_available, "torch.distributed.nn is not available" + ) + def test_broadcast(self): + self._test_broadcast("nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_reduce(self): + self._test_reduce("nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_allreduce(self): + self._test_allreduce("nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_gather(self): + self._test_all_gather("nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all(self): + self._test_all_to_all("nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all_single(self): + self._test_all_to_all_single("nccl") + + # Test Ops only supported in NCCL. + @requires_nccl() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_reduce_scatter(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl') + device = torch.device(f"cuda:{self.rank}") + x0 = torch.ones(5, 5, device=device) + self.rank + x1 = torch.ones(5, 5, device=device) + self.rank + 1 + x0.requires_grad = True + x1.requires_grad = True + y = torch.empty_like(x0) + expected = (1 + self.world_size) * self.world_size / 2 + self.world_size * self.rank + y = torch.distributed.nn.reduce_scatter(y, [x0, x1]) + self.assertEqual(y, torch.ones(5, 5, device=device) * expected) + z = y.sin().sum() + z.backward() + expected_0 = (1 + self.world_size) * self.world_size / 2 + expected_1 = expected_0 + self.world_size + x_s_0 = (expected_0 * torch.ones(5, 5, device=device)).cos() + x_s_1 = (expected_1 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x0.grad, x_s_0) + self.assertEqual(x1.grad, x_s_1) + + +if __name__ == "__main__": run_tests() diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 1d51b233768..0084e4523a9 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1151,14 +1151,17 @@ Arguments: "reduce_scatter", [](::c10d::ProcessGroup& pg, at::Tensor& output, - std::vector& input) { + std::vector& input, + ::c10d::ReduceOp op) { std::vector outputs = {output}; std::vector> inputs = {input}; - return pg.reduce_scatter( - outputs, inputs, ::c10d::ReduceScatterOptions()); + ::c10d::ReduceScatterOptions opts; + opts.reduceOp = op; + return pg.reduce_scatter(outputs, inputs, opts); }, py::arg("output_tensors"), py::arg("input_tensor"), + py::arg("op") = ::c10d::ReduceOp::SUM, py::call_guard()) .def( diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index 7e03fc6e572..c40de387a5e 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -1,6 +1,6 @@ import torch -from torch.autograd import Function import torch.distributed as dist +from torch.autograd import Function def broadcast(tensor, src, group=dist.group.WORLD): @@ -79,6 +79,25 @@ def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=dist.group.WORLD): return _Reduce.apply(dst, op, group, tensor) +def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=dist.group.WORLD): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Arguments: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce_Scatter.apply(op, group, output, *input_list) + + def all_gather(tensor, group=dist.group.WORLD): """ Gathers tensors from the whole group in a list. @@ -88,26 +107,58 @@ def all_gather(tensor, group=dist.group.WORLD): group (ProcessGroup, optional): The process group to work on. Returns: - tuple[Tensor]): Output of the collective. + tuple([Tensor]): Output of the collective. """ return _AllGather.apply(group, tensor) -def all_to_all(tensors, group=dist.group.WORLD): +def all_to_all(output_tensor_list, input_tensor_list, group=dist.group.WORLD): """ Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. Arguments: - tensors (list[Tensor]): List of tensors to scatter one per rank. + out_tensor_list (list[Tensor]): list of tensors to gather one per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. group (ProcessGroup, optional): The process group to work on. Returns: - tuple[Tensor]): Output of the collective. + tuple([Tensor]): Output of the collective. """ - return _AlltoAll.apply(group, *tensors) + return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list) + + +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=dist.group.WORLD, +): + """ + Each process splits input tensor and then scatters the split list + to all processes in a group. Then concatenate the received tensors from all + the processes in the group and return single output tensor. + + Arguments: + output (Tensor): Gathered cancatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + + Returns: + Tensor: Output of the collective. + + """ + return _AlltoAllSingle.apply( + group, output, output_split_sizes, input_split_sizes, input + ) def all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD): @@ -207,6 +258,20 @@ class _Reduce(Function): return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) +class _Reduce_Scatter(Function): + @staticmethod + def forward(ctx, op, group, tensor, *input_tensor_list): + ctx.group = group + dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + _AllGather.apply( + ctx.group, grad_output.contiguous() + ) + + class _AllGather(Function): @staticmethod def forward(ctx, group, tensor): @@ -219,19 +284,19 @@ class _AllGather(Function): @staticmethod def backward(ctx, *grad_outputs): - gxs = _AlltoAll.apply(ctx.group, *grad_outputs) + tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs] + gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) gx = torch.sum(torch.stack(gxs), dim=0) return (None, gx) class _AlltoAll(Function): @staticmethod - def forward(ctx, group, *tensors): + def forward(ctx, group, out_tensor_list, *tensors): ctx.group = group - out_tensor_list = [ - torch.empty_like(tensors[i]) for i in range(dist.get_world_size(group=group)) + ctx.input_tensor_size_list = [ + tensors[i].size() for i in range(dist.get_world_size(group=group)) ] - reqs = [None] * dist.get_world_size(group=group) my_rank = dist.get_rank(group=group) # Implement it on means of scatter/gather, send/recv async operations have issues if dist.get_backend(group=group) is dist.Backend.GLOO: @@ -241,12 +306,51 @@ class _AlltoAll(Function): to_send = list(tensors) dist.scatter(out_tensor_list[i], to_send, i, group=group) else: - dist.all_to_all(out_tensor_list, list(tensors), group=group) + dist.all_to_all( + out_tensor_list, + list(tensors), + group=group, + ) return tuple(out_tensor_list) @staticmethod def backward(ctx, *grad_outputs): - return (None,) + _AlltoAll.apply(ctx.group, *grad_outputs) + tensor_list = [ + torch.empty(size, device=grad_outputs[0].device) + for size in ctx.input_tensor_size_list + ] + grad_outputs = tuple(tensor.contiguous() for tensor in grad_outputs) + return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + + +class _AlltoAllSingle(Function): + @staticmethod + def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): + ctx.group = group + ctx.input_size = input.size() + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + @staticmethod + def backward(ctx, grad_output): + tensor = torch.empty(ctx.input_size, device=grad_output.device) + return (None, None, None, None) + ( + _AlltoAllSingle.apply( + ctx.group, + tensor, + ctx.output_split_sizes, + ctx.input_split_sizes, + grad_output.contiguous(), + ), + ) class _AllReduce(Function):