PT1 Distributed Release MileStone No.1 - Completed Distributed Package and CI tests (#10871)
Summary:
The PR includes:
(1) torch.distributed.c10d, which now includes the complete backward-compatible frontend API for `torch.distributed`
(2) `env://` init method functionality
(3) Minor change to `test_distributed.py`, which is now a test for `torch.distributed.c10d`
(4) The old `test_distributed.py` is now moved to `test_thd_distributed.py`
(5) Miscellaneous bug fixes
(6) The DDP CPU test is removed since c10d doesn't have this support yet, but it will be an easy test to add back once DDP CPU's dependency moves to torch.distributed.c10d
(7) CI config to test the MPI, NCCL, and Gloo backends of c10d

**Now all the distributed tests, including c10d DDP, pass with the c10d frontend API.**

TODO (in a separate PR): MPI subgroup support; once this is added, the CI group tests will be enabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10871
Differential Revision: D9554514
Pulled By: teng-li
fbshipit-source-id: fb686ad42258526c8b4372148e82969fac4f42dd
parent fa7c81c640
commit 56539f5fe1

12 changed files with 2404 additions and 118 deletions
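
For reference, a minimal sketch of driving the new backward-compatible frontend with the `env://` init method described above. The backend choice, addresses, and the toy all_reduce call are illustrative only (the test changes below exercise the same calls through helpers); in CI these environment variables are set by run_test.py rather than by hand.

    import os
    import torch
    import torch.distributed.c10d as dist

    # Illustrative single-process setup; values are placeholders.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("RANK", "0")

    dist.init_process_group(init_method="env://",
                            backend="gloo",
                            world_size=int(os.environ["WORLD_SIZE"]),
                            rank=int(os.environ["RANK"]))

    t = torch.ones(4)
    dist.all_reduce(t)   # backward-compatible collective, as exercised by test_distributed.py
    print(t)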

@@ -1,15 +1,28 @@
#!/bin/bash

# For distributed, four environmental configs:
# (1) build with only NCCL
# (2) build with NCCL and MPI
# (3) build with only MPI
# (4) build with neither
if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-* ]]; then
  # TODO: move this to Docker
  sudo apt-get update
  sudo apt-get install libnccl-dev=2.2.13-1+cuda9.0 libnccl2=2.2.13-1+cuda9.0
fi

if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda8-* ]] || [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-cudnn7-py2* ]]; then
  # TODO: move this to Docker
  sudo apt-get update
  sudo apt-get install openmpi-bin libopenmpi-dev
  sudo apt-get install -y --no-install-recommends openssh-client openssh-server
  sudo mkdir -p /var/run/sshd
fi

if [[ "$BUILD_ENVIRONMENT" == "pytorch-linux-xenial-py3-clang5-asan" ]]; then
  exec "$(dirname "${BASH_SOURCE[0]}")/build-asan.sh" $*
fi

# TODO: move this to Docker
# TODO: add both NCCL and MPI in CI test by fixing these test first
sudo apt-get update
sudo apt-get install libnccl-dev libnccl2
# sudo apt-get install openmpi-bin libopenmpi-dev

# Required environment variable: $BUILD_ENVIRONMENT
# (This is set by default in the Docker images we build, so you don't
# need to set it yourself.

@@ -17,6 +17,7 @@ import unittest
import warnings
import random
import contextlib
import socket
from functools import wraps
from itertools import product
from copy import deepcopy

@@ -550,3 +551,12 @@ def download_file(url, binary=True):
        msg = "could not download test file '{}'".format(url)
        warnings.warn(msg, RuntimeWarning)
        raise unittest.SkipTest(msg)


def find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('localhost', 0))
    sockname = sock.getsockname()
    sock.close()
    return sockname[1]

@@ -14,6 +14,7 @@ import tempfile
import torch
from torch.utils import cpp_extension
from common import TEST_WITH_ROCM
import torch.distributed.c10d as c10d

TESTS = [
    'autograd',

@@ -31,12 +32,14 @@ TESTS = [
    'nn',
    'optim',
    'sparse',
    'thd_distributed',
    'torch',
    'utils',
]

WINDOWS_BLACKLIST = [
    'distributed',
    'thd_distributed',
]

ROCM_BLACKLIST = [

@@ -50,10 +53,29 @@ ROCM_BLACKLIST = [
    'multiprocessing',
    'nccl',
    'nn',
    'thd_distributed',
    'utils',
]

DISTRIBUTED_TESTS_CONFIG = {
    'gloo': {
        'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3'
    },
}


if c10d.is_available():
    if c10d.is_mpi_available():
        DISTRIBUTED_TESTS_CONFIG['mpi'] = {
            'WORLD_SIZE': '3'
        }
    if c10d.is_nccl_available():
        DISTRIBUTED_TESTS_CONFIG['nccl'] = {
            'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3'
        }


THD_DISTRIBUTED_TESTS_CONFIG = {
    'tcp': {
        'WORLD_SIZE': '3'
    },

@@ -126,7 +148,10 @@ def test_distributed(python, test_module, test_directory, options):
    if options.verbose and not mpi_available:
        print_to_stderr(
            'MPI not available -- MPI backend tests will be skipped')
    for backend, env_vars in DISTRIBUTED_TESTS_CONFIG.items():
    config = DISTRIBUTED_TESTS_CONFIG
    if test_module == "test_thd_distributed":
        config = THD_DISTRIBUTED_TESTS_CONFIG
    for backend, env_vars in config.items():
        if backend == 'mpi' and not mpi_available:
            continue
        for with_init_file in {True, False}:

@@ -141,7 +166,10 @@ def test_distributed(python, test_module, test_directory, options):
            os.environ['INIT_METHOD'] = 'env://'
            os.environ.update(env_vars)
            if with_init_file:
                init_method = 'file://{}/shared_init_file'.format(tmp_dir)
                if test_module == "test_distributed":
                    init_method = 'file://{}/'.format(tmp_dir)
                else:
                    init_method = 'file://{}/shared_init_file'.format(tmp_dir)
                os.environ['INIT_METHOD'] = init_method
            try:
                os.mkdir(os.path.join(tmp_dir, 'barrier'))

@@ -170,6 +198,7 @@ def test_distributed(python, test_module, test_directory, options):
CUSTOM_HANDLERS = {
    'cpp_extensions': test_cpp_extensions,
    'distributed': test_distributed,
    'thd_distributed': test_distributed,
}
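
For reference, a rough hand-rolled equivalent of one configuration the runner above composes. The gloo backend, the world size of 3, and the paths are assumptions for illustration only, and run_test.py also prepares further subdirectories under TEMP_DIR that are not visible in this hunk.

    import os
    import subprocess
    import sys
    import tempfile

    # Approximate what test_distributed() above does for one backend:
    # gloo, env:// init, WORLD_SIZE taken from DISTRIBUTED_TESTS_CONFIG.
    tmp_dir = tempfile.mkdtemp()
    os.mkdir(os.path.join(tmp_dir, 'barrier'))   # run_test.py creates this too
    env = dict(os.environ)
    env.update({
        'TEMP_DIR': tmp_dir,
        'BACKEND': 'gloo',
        'WORLD_SIZE': '3',
        'INIT_METHOD': 'env://',
    })
    subprocess.check_call([sys.executable, 'test/test_distributed.py'], env=env)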

@@ -1,7 +1,6 @@
import copy
import math
import multiprocessing
import socket
import sys
import tempfile
import unittest

@@ -10,6 +9,7 @@ from functools import wraps
from collections import namedtuple

import torch
import common
from torch import nn
import torch.nn.functional as F
from torch.distributed import c10d

@@ -60,15 +60,6 @@ def get_timeout(test_id):
    return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT)


def find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('localhost', 0))
    sockname = sock.getsockname()
    sock.close()
    return sockname[1]


def gpus_for_rank(world_size):
    """Multigpu tests are designed to simulate the multi nodes with multi
    GPUs on each node. Nccl backend requires equal #GPUs in each process.

@@ -126,14 +117,14 @@ class PrefixFileStoreTest(TestCase, StoreTestBase):
class TCPStoreTest(TestCase, StoreTestBase):
    def _create_store(self):
        addr = 'localhost'
        port = find_free_port()
        port = common.find_free_port()
        return c10d.TCPStore(addr, port, True)


class PrefixTCPStoreTest(TestCase, StoreTestBase):
    def setUp(self):
        addr = 'localhost'
        port = find_free_port()
        port = common.find_free_port()
        self.tcpstore = c10d.TCPStore(addr, port, True)
        self.prefix = "test_prefix"

@@ -150,10 +141,10 @@ class RendezvousTest(TestCase):
class RendezvousFileTest(TestCase):
    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, 'path missing'):
            gen = c10d.rendezvous('file://?rank=0&size=1')
            gen = c10d.rendezvous('file://?rank=0&world_size=1')
            next(gen)
        with self.assertRaisesRegex(ValueError, 'rank parameter missing'):
            gen = c10d.rendezvous('file:///tmp/foo?size=1')
            gen = c10d.rendezvous('file:///tmp/foo?world_size=1')
            next(gen)
        with self.assertRaisesRegex(ValueError, 'size parameter missing'):
            gen = c10d.rendezvous('file:///tmp/foo?rank=0')

@@ -161,7 +152,7 @@ class RendezvousFileTest(TestCase):

    def test_nominal(self):
        with tempfile.NamedTemporaryFile() as file:
            url = 'file://%s?size=%d' % (file.name, 2)
            url = 'file://%s?world_size=%d' % (file.name, 2)
            gen0 = c10d.rendezvous(url + "&rank=0")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)

@@ -183,10 +174,10 @@ class RendezvousFileTest(TestCase):
class RendezvousTCPTest(TestCase):
    def test_common_errors(self):
        with self.assertRaisesRegex(ValueError, 'port number missing'):
            gen = c10d.rendezvous('tcp://127.0.0.1?rank=0&size=1')
            gen = c10d.rendezvous('tcp://127.0.0.1?rank=0&world_size=1')
            next(gen)
        with self.assertRaisesRegex(ValueError, 'rank parameter missing'):
            gen = c10d.rendezvous('tcp://127.0.0.1:23456?size=1')
            gen = c10d.rendezvous('tcp://127.0.0.1:23456?world_size=1')
            next(gen)
        with self.assertRaisesRegex(ValueError, 'size parameter missing'):
            gen = c10d.rendezvous('tcp://127.0.0.1:23456?rank=0')

@@ -194,8 +185,8 @@ class RendezvousTCPTest(TestCase):

    def test_nominal(self):
        addr = 'localhost'
        port = find_free_port()
        url = 'tcp://%s:%d?size=%d' % (addr, port, 2)
        port = common.find_free_port()
        url = 'tcp://%s:%d?world_size=%d' % (addr, port, 2)
        gen0 = c10d.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)

@@ -245,7 +236,7 @@ class MultiProcessTestCase(TestCase):
    def setUp(self):
        self.rank = self.MAIN_PROCESS_RANK
        self.file = tempfile.NamedTemporaryFile()
        self.port = find_free_port()
        self.port = common.find_free_port()
        self.processes = [self._spawn_process(rank) for rank in range(int(self.world_size))]

    def tearDown(self):

@@ -529,8 +520,9 @@ class DistributedDataParallelTest(MultiProcessTestCase):
        model = Net()
        ddp_model = distributed_c10d._DistributedDataParallelC10d(
            copy.deepcopy(model).cuda(gpus[0]),
            process_group,
            device_ids=gpus)
            device_ids=gpus,
            process_group=process_group)

        model.cuda(gpus[0])

        local_batch_size = len(gpus)

@@ -5,29 +5,32 @@ import multiprocessing
import os
import sys
import time
import tempfile
import unittest
from contextlib import contextmanager
from functools import reduce, wraps

import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.c10d as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from common import TestCase
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch.autograd import Variable

import common

BACKEND = os.environ["BACKEND"]
TEMP_DIR = os.environ["TEMP_DIR"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
MASTER_PORT = "29500"

DEFAULT_TIMEOUT = 300
CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}

if INIT_METHOD.startswith("file://"):
    FOLDER = INIT_METHOD[7:]


def get_timeout(test_id):
    test_name = test_id.split(".")[-1]

@@ -361,8 +364,9 @@ class _DistTestBase(object):
        rank_to_GPU = self._init_multigpu_helper()
        self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU)

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @skip_if_small_worldsize
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    def test_broadcast_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_broadcast_helper(group, group_id, rank)

@@ -454,7 +458,8 @@ class _DistTestBase(object):
        self._test_reduce_helper(group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10)

    @unittest.skipIf(BACKEND == "gloo", "Gloo does not support reduce")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    @skip_if_small_worldsize
    def test_reduce_group_sum(self):
        group, group_id, rank = self._init_group_test()

@@ -469,7 +474,8 @@ class _DistTestBase(object):
        )

    @unittest.skipIf(BACKEND == "gloo", "Gloo does not support reduce")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    @skip_if_small_worldsize
    def test_reduce_group_product(self):
        group, group_id, rank = self._init_group_test()

@@ -484,14 +490,16 @@ class _DistTestBase(object):
        )

    @unittest.skipIf(BACKEND == "gloo", "Gloo does not support reduce")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    @skip_if_small_worldsize
    def test_reduce_group_min(self):
        group, group_id, rank = self._init_group_test()
        self._test_reduce_helper(group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1)

    @unittest.skipIf(BACKEND == "gloo", "Gloo does not support reduce")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    @skip_if_small_worldsize
    def test_reduce_group_max(self):
        group, group_id, rank = self._init_group_test()

@@ -540,8 +548,8 @@ class _DistTestBase(object):
        )

    @unittest.skipIf(
        BACKEND != "gloo" and BACKEND != "nccl",
        "Only Gloo & Nccl backend support CUDA allReduce",
        BACKEND != "gloo",
        "Only Gloo backend will have CUDA allReduce tested",
    )
    @skip_if_no_cuda_distributed
    @skip_if_no_gpu

@@ -587,8 +595,9 @@ class _DistTestBase(object):
            group, group_id, rank, dist.reduce_op.MAX, -1, 10, 10
        )

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @skip_if_small_worldsize
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    def test_all_reduce_group_sum(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(

@@ -601,8 +610,9 @@ class _DistTestBase(object):
            2 + (10 * (len(group) - 1)),
        )

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @skip_if_small_worldsize
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    def test_all_reduce_group_product(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(

@@ -615,16 +625,18 @@ class _DistTestBase(object):
            reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
        )

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @skip_if_small_worldsize
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    def test_all_reduce_group_min(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(
            group, group_id, rank, dist.reduce_op.MIN, 1010, 1, 1
        )

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @skip_if_small_worldsize
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    def test_all_reduce_group_max(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_reduce_helper(

@@ -652,6 +664,7 @@ class _DistTestBase(object):

    @unittest.skipIf(BACKEND == "gloo", "Gloo does not support scatter")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support scatter")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    @skip_if_small_worldsize
    def test_scatter_group(self):
        group, group_id, rank = self._init_group_test()

@@ -679,7 +692,8 @@ class _DistTestBase(object):
        self._test_gather_helper(group, group_id, rank)

    @unittest.skipIf(BACKEND == "gloo", "Gloo does not support gather")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    @skip_if_small_worldsize
    def test_gather_group(self):
        group, group_id, rank = self._init_group_test()

@@ -703,12 +717,13 @@ class _DistTestBase(object):

        self._barrier()

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports CPU all gather")
    def test_all_gather(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_gather_helper(group, group_id, rank)

    @unittest.skipIf(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
    @unittest.skipIf(BACKEND == "nccl", "CUDA all gather skipped for NCCL")
    @skip_if_no_cuda_distributed
    @skip_if_no_gpu
    def test_all_gather_cuda(self):

@@ -716,8 +731,10 @@ class _DistTestBase(object):
        rank_to_GPU = self._init_multigpu_helper()
        self._test_all_gather_helper(group, group_id, rank, True, rank_to_GPU)

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @skip_if_small_worldsize
    @unittest.skipIf(BACKEND == "gloo", "Gloo does not support gather")
    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    def test_all_gather_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_gather_helper(group, group_id, rank)

@@ -740,13 +757,14 @@ class _DistTestBase(object):

        self._barrier()

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports barrier")
    def test_barrier(self):
        group, group_id, rank = self._init_global_test()
        self._test_barrier_helper(group, group_id, rank)

    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support newGroup")
    @skip_if_small_worldsize
    @unittest.skipIf(BACKEND != "mpi", "Only MPI supports barrier")
    @unittest.skipIf(BACKEND == "mpi", "MPI does not support group")
    def test_barrier_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_barrier_helper(group, group_id, rank)

@@ -765,7 +783,8 @@ class _DistTestBase(object):
            self.assertEqual(tensor, expected_tensor)
        self._barrier()

    @unittest.skipIf(BACKEND != "nccl", "Only Nccl backend supports broadcast multigpu")
    @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu")
    @unittest.skipIf(BACKEND == "nccl", "NCCL broadcast multigpu skipped")
    @skip_if_no_gpu
    def test_broadcast_multigpu(self):
        group, group_id, rank = self._init_global_test()

@@ -802,7 +821,8 @@ class _DistTestBase(object):

        self._barrier()

    @unittest.skipIf(BACKEND != "nccl", "Only Nccl backend supports allreduce multigpu")
    @unittest.skipIf(BACKEND == "mpi", "MPI doesn't support broadcast multigpu")
    @unittest.skipIf(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL")
    @skip_if_no_gpu
    def test_all_reduce_multigpu(self):
        group, group_id, rank = self._init_global_test()

@@ -985,7 +1005,7 @@ class _DistTestBase(object):
        # DDP training setup
        model_DDP = copy.deepcopy(model)
        model_DDP.cuda(gpu_subset[0])
        model_DDP = nn.parallel.DistributedDataParallel(
        model_DDP = nn.parallel._DistributedDataParallelC10d(
            model_DDP, device_ids=gpu_subset
        )

@@ -1006,33 +1026,8 @@ class _DistTestBase(object):
        )
        self._barrier()

    @unittest.skipIf(
        BACKEND == "nccl", "nccl does not support DistributedDataParallelCPU"
    )
    def test_DistributedDataParallelCPU(self):
        # Run a simple end to end DDP-CPU model, use result of single node
        # model as baseline
        group, group_id, rank = self._init_global_test()

        # cpu training setup
        model_base = self._create_Net()

        # DDP-CPU training setup
        model_DDP = copy.deepcopy(model_base)
        model_DDP = nn.parallel.DistributedDataParallelCPU(model_DDP)

        # dummy data initialization
        local_bs = 2
        global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)

        # check two model parameters over 2 iterations
        self._test_DDP_2iter(
            model_base, model_DDP, input_cpu, target, loss, local_bs, rank, global_bs
        )
        self._barrier()


if BACKEND == "tcp" or BACKEND == "gloo" or BACKEND == "nccl":
if BACKEND == "gloo" or BACKEND == "nccl":
    WORLD_SIZE = os.environ["WORLD_SIZE"]

    class TestDistBackend(TestCase, _DistTestBase):

@@ -1052,7 +1047,6 @@ if BACKEND == "tcp" or BACKEND == "gloo" or BACKEND == "nccl":
        @classmethod
        def setUpClass(cls):
            os.environ["MASTER_ADDR"] = MASTER_ADDR
            os.environ["MASTER_PORT"] = MASTER_PORT
            os.environ["WORLD_SIZE"] = WORLD_SIZE
            for attr in dir(cls):
                if attr.startswith("test"):

@@ -1060,6 +1054,17 @@ if BACKEND == "tcp" or BACKEND == "gloo" or BACKEND == "nccl":
                        setattr(cls, attr, cls.manager_join(fn))

        def setUp(self):
            # Adding this hack until we fix the FileStore to delete its
            # content at the end
            global INIT_METHOD
            if INIT_METHOD.startswith("file://"):
                _, filename = tempfile.mkstemp(prefix=FOLDER)
                INIT_METHOD = "file://{}".format(filename)

            if INIT_METHOD.startswith("env://"):
                port = common.find_free_port()
                os.environ["MASTER_PORT"] = str(port)

            self.processes = []
            self.rank = self.MANAGER_PROCESS_RANK
            Barrier.init()

@@ -1081,7 +1086,10 @@ if BACKEND == "tcp" or BACKEND == "gloo" or BACKEND == "nccl":
            self.rank = rank
            try:
                dist.init_process_group(
                    init_method=INIT_METHOD, backend=BACKEND, world_size=int(WORLD_SIZE)
                    init_method=INIT_METHOD,
                    backend=BACKEND,
                    world_size=int(WORLD_SIZE),
                    rank=self.rank
                )
            except RuntimeError as e:
                if "recompile" in e.args[0]:

test/test_thd_distributed.py — 1148 lines (normal file)
File diff suppressed because it is too large.

@@ -346,8 +346,8 @@ PyObject* c10d_init(PyObject* _unused) {
#endif

  shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
      .def("isCompleted", &::c10d::ProcessGroup::Work::isCompleted)
      .def("isSuccess", &::c10d::ProcessGroup::Work::isSuccess)
      .def("is_completed", &::c10d::ProcessGroup::Work::isCompleted)
      .def("is_success", &::c10d::ProcessGroup::Work::isSuccess)
      .def("exception", &::c10d::ProcessGroup::Work::exception)
      .def("synchronize", &::c10d::ProcessGroup::Work::synchronize)
      .def(
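
The hunk above renames the Python-facing methods of `Work` to snake_case. A minimal sketch of how a pending collective could be polled through the renamed bindings; `pg` is assumed to be an existing c10d process group (for example the ProcessGroupGloo built in test_c10d.py) rather than something this PR constructs here.

    import torch
    import torch.distributed.c10d as c10d

    t = torch.ones(8)
    work = pg.allreduce([t], c10d.AllreduceOptions())   # returns a Work handle
    while not work.is_completed():                      # formerly isCompleted()
        pass                                            # spin until the async op finishes
    assert work.is_success()                            # formerly isSuccess()
    work.synchronize()                                  # stream synchronization for CUDA backends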

@@ -6,20 +6,8 @@ def is_available():


if is_available() and not torch._C._c10d_init():
    raise RuntimeError("c10d initialization failed")
    raise RuntimeError("Failed to initialize PyTorch distributed support")


if is_available():
    from .rendezvous import rendezvous, register_rendezvous_handler
    from . import BroadcastOptions, AllreduceOptions

    DEFAULT_REDUCE_OPTIONS = AllreduceOptions()

    def broadcast(tensor, src, process_group):
        opts = BroadcastOptions()
        opts.rootRank = src
        opts.rootTensor = 0
        return process_group.broadcast([tensor], opts)

    def all_reduce(tensor, process_group, opts=DEFAULT_REDUCE_OPTIONS):
        return process_group.allreduce([tensor], opts)
    from .distributed_c10d import *

torch/distributed/c10d/distributed_c10d.py — 1054 lines (normal file)
File diff suppressed because it is too large.

@@ -3,6 +3,7 @@ try:
except ImportError:
    from urlparse import urlparse

import os
from . import FileStore, TCPStore


@@ -59,13 +60,13 @@ def _file_rendezvous_handler(url):
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
    if "rank" not in query:
        raise _error("rank parameter missing")
    if "size" not in query:
        raise _error("size parameter missing")
    if "world_size" not in query:
        raise _error("world size parameter missing")

    rank = int(query["rank"])
    size = int(query["size"])
    world_size = int(query["world_size"])
    store = FileStore(path)
    yield (store, rank, size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")

@@ -81,18 +82,52 @@ def _tcp_rendezvous_handler(url):
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
    if "rank" not in query:
        raise _error("rank parameter missing")
    if "size" not in query:
        raise _error("size parameter missing")
    if "world_size" not in query:
        raise _error("world size parameter missing")

    rank = int(query["rank"])
    size = int(query["size"])
    world_size = int(query["world_size"])
    start_daemon = rank == 0
    store = TCPStore(result.hostname, result.port, start_daemon)
    yield (store, rank, size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using tcp:// method")


def _env_rendezvous_handler(url):
    def _error(msg):
        return ValueError("env:// rendezvous: " + msg)

    if url != "env://":
        raise _error("Only `env://` is expected for the env init method")
    world_size = os.environ["WORLD_SIZE"]
    if world_size is None:
        raise _error("world size is missing")
    rank = os.environ["RANK"]
    if rank is None:
        raise _error("rank is missing")
    master_addr = os.environ["MASTER_ADDR"]
    if master_addr is None:
        raise _error("master addr is missing")
    master_port = os.environ["MASTER_PORT"]
    if master_port is None:
        raise _error("master port is missing")

    # Converting before creating the store
    rank = int(rank)
    world_size = int(world_size)
    master_port = int(master_port)

    # Now start the TCP store daemon on the rank 0
    start_daemon = rank == 0
    store = TCPStore(master_addr, master_port, start_daemon)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using env:// method")


register_rendezvous_handler("file", _file_rendezvous_handler)
register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
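
Rendezvous handlers are generators keyed by URL scheme that yield a single (store, rank, world_size) triple, as the env:// handler above does. A purely illustrative sketch of registering an extra, hypothetical scheme the same way; the "static://" name, the path, and the hard-coded values are not part of this PR.

    from torch.distributed.c10d import FileStore
    from torch.distributed.c10d.rendezvous import register_rendezvous_handler

    def _static_rendezvous_handler(url):
        # Hypothetical "static://" scheme: single process, fixed FileStore path.
        rank, world_size = 0, 1
        store = FileStore("/tmp/static_rendezvous")
        yield (store, rank, world_size)
        raise RuntimeError("Unable to perform rerendezvous using static:// method")

    register_rendezvous_handler("static", _static_rendezvous_handler)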

@@ -386,16 +386,17 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (outputTensors.size() != 1) {
    throw std::runtime_error("Gather: multi-GPU collective is not supported");
  }

  if (rank_ != opts.rootRank) {
    if (outputTensors.size() > 0) {
    if (outputTensors[0].size() > 0) {
      throw std::runtime_error(
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      throw std::runtime_error("Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      throw std::runtime_error(
          "Gather: number of output tensors should equal "

@@ -449,17 +450,17 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);
  if (inputTensors.size() != 1) {
    throw std::runtime_error("Scatter: multi-GPU collective is not supported");
  }

  if (rank_ != opts.rootRank) {
    if (inputTensors.size() > 0) {
    if (inputTensors[0].size() > 0) {
      throw std::runtime_error(
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      throw std::runtime_error("Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      throw std::runtime_error(
          "Scatter: number of input tensors should equal "
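
The relaxed checks above let a non-root rank pass an empty output/input list to gather/scatter instead of a dummy full-sized one. Assuming the backward-compatible frontend keeps the classic gather keywords (gather_list/dst, as in the pre-c10d torch.distributed API — an assumption, since that code lives in the suppressed distributed_c10d.py), the calling pattern this enables looks roughly like the following; `dist`, `rank`, `world_size`, and `dst` are assumed to exist as in test_distributed.py above.

    # Sketch only; not part of this PR.
    tensor = torch.full((10,), rank)
    if rank == dst:
        gather_list = [torch.zeros(10) for _ in range(world_size)]
    else:
        gather_list = []          # non-root ranks now pass no output tensors
    dist.gather(tensor, gather_list=gather_list, dst=dst)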

@@ -91,13 +91,14 @@ class _DistributedDataParallelC10d(Module):

    Args:
        module: module to be parallelized
        process_group: the c10d process group to be used for distributed data
                       all-reduction
        device_ids: CUDA devices (default: all devices)
        output_device: device location of output (default: device_ids[0])
        broadcast_buffers: flag that enables syncing (broadcasting) buffers of
                           the module at beginning of the forward function.
                           (default: True)
        process_group: the c10d process group to be used for distributed data
                       all-reduction. If None, the default process group will
                       be used
        bucket_cap_mb: DistributedDataParallelC10d will bucket parameters into
                       multiple buckets so that gradient reduction of each
                       bucket can potentially overlap with backward computation.

@@ -112,9 +113,9 @@ class _DistributedDataParallelC10d(Module):
        >>> pg = torch.distributed.c10d.ProcessGroupGloo(store, rank, world_size)
        >>> net = torch.nn._DistributedDataParallelC10d(model, pg)
    """
    def __init__(self, module, process_group, device_ids=None,
    def __init__(self, module, device_ids=None,
                 output_device=None, dim=0, broadcast_buffers=True,
                 bucket_cap_mb=25):
                 process_group=None, bucket_cap_mb=25):

        super(_DistributedDataParallelC10d, self).__init__()

@@ -125,13 +126,19 @@ class _DistributedDataParallelC10d(Module):
        if output_device is None:
            output_device = device_ids[0]

        if process_group is None:
            self.process_group = c10d.get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.process_group = process_group
        self.device_ids = device_ids
        self.output_device = output_device
        self.broadcast_buffers = broadcast_buffers

        self.allreduce_opts = c10d.AllreduceOptions()

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well

@@ -341,7 +348,8 @@ class _DistributedDataParallelC10d(Module):
            nccl.reduce(grads_batch_coalesced, root=0, streams=self.default_streams)

        # now work on the first gpu
        reduction_work = c10d.all_reduce(grads_batch_coalesced[0], self.process_group)
        reduction_work = self.process_group.allreduce([grads_batch_coalesced[0]],
                                                      self.allreduce_opts)
        self.reduction_works[bucket_idx] = reduction_work
        self.buckets_coalesced[bucket_idx] = grads_batch_coalesced[0]
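
With the signature change above, the process group becomes an optional keyword argument that falls back to the default group. A sketch of constructing the wrapper the way the updated DistributedDataParallelTest in test_c10d.py does; the import location of _DistributedDataParallelC10d and the pre-existing store/rank/world_size/model/gpus objects are assumptions here, not something this PR pins down.

    import copy
    import torch.distributed.c10d as c10d
    from torch.nn.parallel import distributed_c10d   # location assumed from the test's usage

    # `store`, `rank`, `world_size`, `model`, and `gpus` are assumed to exist,
    # e.g. created as in DistributedDataParallelTest.
    pg = c10d.ProcessGroupGloo(store, rank, world_size)
    ddp_model = distributed_c10d._DistributedDataParallelC10d(
        copy.deepcopy(model).cuda(gpus[0]),
        device_ids=gpus,
        process_group=pg)   # optional; omit to fall back to c10d.get_default_group()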