pytorch/test/test_c10d.py
Pieter Noordhuis 3da27312bb
Export ProcessGroupGloo options to Python (#8664)
This surfaces the options struct that can be passed to the
ProcessGroupGloo constructor to Python. By default, if no options struct
is passed at construction time, the Python bindings default to using a
struct with a TCP backed Gloo device that uses the machine's hostname to
resolve the IP address to bind to.
2018-06-20 09:08:06 -07:00

360 lines
11 KiB
Python

import math
import multiprocessing
import sys
import tempfile
import unittest
from functools import wraps
import torch
import torch.distributed.c10d as c10d
from common import TestCase
TCP_ADDR = '127.0.0.1'
TCP_PORT = 29500

TIMEOUT_DEFAULT = 5
TIMEOUT_OVERRIDE = {}


def get_timeout(test_id):
    """Return the timeout in seconds for a test, honoring per-test overrides.

    The test name is the last dot-separated component of ``test_id``
    (e.g. ``'__main__.SomeTest.test_foo'`` -> ``'test_foo'``); if it
    appears in TIMEOUT_OVERRIDE that value wins, otherwise the default.
    """
    test_name = test_id.split('.')[-1]
    return TIMEOUT_OVERRIDE.get(test_name, TIMEOUT_DEFAULT)
# Bail out of the whole module early when the c10d extension is not
# compiled into this torch build; exit with status 0 so the test runner
# treats the module as skipped rather than failed.
if not c10d.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)
class StoreTestBase(object):
    """Mixin with store tests shared by FileStoreTest and TCPStoreTest.

    Subclasses must override ``_create_store`` to return a concrete
    c10d store instance.
    """

    def _create_store(self):
        # Fixed: the stub previously declared a spurious ``i`` parameter
        # that no caller supplied (test_set_get calls it with no
        # arguments, as do the subclass overrides), so invoking the base
        # stub raised TypeError instead of this intended error.
        raise RuntimeError("not implemented")

    def _test_set_get(self, fs):
        """Round-trip three keys through the store; values read back as bytes."""
        fs.set("key0", "value0")
        fs.set("key1", "value1")
        fs.set("key2", "value2")
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))

    def test_set_get(self):
        self._test_set_get(self._create_store())
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase tests against a file-backed store."""

    def setUp(self):
        # A fresh temporary file per test backs the store.
        self.file = tempfile.NamedTemporaryFile()

    def tearDown(self):
        self.file.close()

    def _create_store(self):
        return c10d.FileStore(self.file.name)
class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared StoreTestBase tests against a TCP-backed store."""

    def _create_store(self):
        # NOTE(review): the True flag presumably marks this process as the
        # one hosting the store — confirm against the TCPStore binding.
        return c10d.TCPStore(TCP_ADDR, TCP_PORT, True)
class RendezvousTest(TestCase):
    """Checks rendezvous URL-scheme dispatch."""

    def test_unknown_handler(self):
        # A URL scheme with no registered handler must be rejected up front.
        with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"):
            c10d.rendezvous('invalid://')
class RendezvousFileTest(TestCase):
    """Tests the file:// rendezvous handler."""

    def test_common_errors(self):
        # Every malformed URL must fail with a descriptive ValueError.
        cases = [
            ('path missing', 'file://?rank=0&size=1'),
            ('rank parameter missing', 'file:///tmp/foo?size=1'),
            ('size parameter missing', 'file:///tmp/foo?rank=0'),
        ]
        for message, url in cases:
            with self.assertRaisesRegex(ValueError, message):
                gen = c10d.rendezvous(url)
                next(gen)

    def test_nominal(self):
        with tempfile.NamedTemporaryFile() as file:
            url = 'file://%s?size=%d' % (file.name, 2)

            store0, rank0, size0 = next(c10d.rendezvous(url + "&rank=0"))
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)

            store1, rank1, size1 = next(c10d.rendezvous(url + "&rank=1"))
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Set a value on each store, then cross-check through the other:
            # both stores must observe both keys.
            store0.set("key0", "value0")
            store1.set("key1", "value1")
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
class RendezvousTCPTest(TestCase):
    """Tests the tcp:// rendezvous handler."""

    def test_common_errors(self):
        # Every malformed URL must fail with a descriptive ValueError.
        cases = [
            ('port number missing', 'tcp://127.0.0.1?rank=0&size=1'),
            ('rank parameter missing', 'tcp://127.0.0.1:23456?size=1'),
            ('size parameter missing', 'tcp://127.0.0.1:23456?rank=0'),
        ]
        for message, url in cases:
            with self.assertRaisesRegex(ValueError, message):
                gen = c10d.rendezvous(url)
                next(gen)

    def test_nominal(self):
        url = 'tcp://127.0.0.1:23456?size=%d' % 2

        store0, rank0, size0 = next(c10d.rendezvous(url + "&rank=0"))
        self.assertEqual(0, rank0)
        self.assertEqual(2, size0)

        store1, rank1, size1 = next(c10d.rendezvous(url + "&rank=1"))
        self.assertEqual(1, rank1)
        self.assertEqual(2, size1)

        # Set a value on each store, then cross-check through the other:
        # both stores must observe both keys.
        store0.set("key0", "value0")
        store1.set("key1", "value1")
        self.assertEqual(b"value0", store1.get("key0"))
        self.assertEqual(b"value1", store0.get("key1"))
class ProcessGroupGlooTest(TestCase):
    """Multi-process tests for the Gloo process group.

    setUpClass wraps every test function: in the main process
    (rank == MAIN_PROCESS_RANK) the wrapper only joins the spawned
    subprocesses; in each subprocess it runs the underlying test with
    that subprocess's rank.
    """

    MAIN_PROCESS_RANK = -1

    @staticmethod
    def join_or_run(fn):
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn(self)
        return wrapper

    # The main process spawns N subprocesses that run the test.
    # This function patches every test function to either assume the
    # role of the main process and join its subprocesses, or run the
    # underlying test function.
    @classmethod
    def setUpClass(cls):
        for attr in dir(cls):
            if attr.startswith('test'):
                fn = getattr(cls, attr)
                setattr(cls, attr, cls.join_or_run(fn))

    def setUp(self):
        self.rank = self.MAIN_PROCESS_RANK
        self.size = 4
        self.file = tempfile.NamedTemporaryFile()
        self.processes = [self._spawn_process(rank) for rank in range(self.size)]

    def tearDown(self):
        for p in self.processes:
            p.terminate()
        self.file.close()

    def _spawn_process(self, rank):
        name = 'process ' + str(rank)
        process = multiprocessing.Process(target=self._run, name=name, args=(rank,))
        process.start()
        return process

    def _run(self, rank):
        self.rank = rank
        # self.id() == e.g. '__main__.TestDistributed.test_get_rank';
        # retrieve the corresponding test function and execute it.
        getattr(self, self.id().split(".")[2])()
        sys.exit(0)

    def _join_processes(self, fn):
        timeout = get_timeout(self.id())
        for p in self.processes:
            p.join(timeout)
            # Fixed: exit codes were previously ignored, so a subprocess
            # that crashed or failed an assertion went unnoticed and the
            # test silently passed. A process still alive after the
            # timeout has exitcode None and fails here as well.
            self.assertEqual(p.exitcode, 0)

    def opts(self):
        """Gloo options with a short timeout and a loopback TCP device."""
        opts = c10d.ProcessGroupGloo.Options()
        opts.timeout = 1.0
        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
        return opts

    def test_broadcast_ops(self):
        store = c10d.FileStore(self.file.name)
        pg = c10d.ProcessGroupGloo(store, self.rank, self.size, self.opts())

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            work = pg.broadcast(xs, opts)
            work.wait()

        # Every rank is root once, every tensor index is root once.
        for i in range(self.size):
            for j in range(2):
                xs = [
                    torch.Tensor([self.rank * self.size + 0.0]),
                    torch.Tensor([self.rank * self.size + 1.0]),
                ]
                broadcast(xs, i, j)
                # After broadcast, both tensors hold root i's j-th value.
                self.assertEqual(torch.Tensor([i * self.size + j]), xs[0])
                self.assertEqual(torch.Tensor([i * self.size + j]), xs[1])

        # Test overloaded convenience function.
        x = torch.Tensor([self.rank + 1.0])
        work = pg.broadcast(x, root=0)
        work.wait()
        self.assertEqual(torch.Tensor([1.0]), x)

    def test_allreduce_ops(self):
        store = c10d.FileStore(self.file.name)
        pg = c10d.ProcessGroupGloo(store, self.rank, self.size, self.opts())

        def allreduce(x, op):
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            work = pg.allreduce([x], opts)
            work.wait()

        # Sum: ranks contribute 1..size, so the result is size*(size+1)/2.
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.SUM)
        self.assertEqual(torch.Tensor([float(self.size * (self.size + 1) / 2)]), x)

        # Product: 1 * 2 * ... * size == size!.
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.PRODUCT)
        self.assertEqual(torch.Tensor([float(math.factorial(self.size))]), x)

        # Min
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.MIN)
        self.assertEqual(torch.Tensor([1.0]), x)

        # Max
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.MAX)
        self.assertEqual(torch.Tensor([self.size]), x)

        # Test overloaded convenience function (defaults to using sum).
        x = torch.Tensor([self.rank + 1.0])
        work = pg.allreduce(x)
        work.wait()
        self.assertEqual(torch.Tensor([float(self.size * (self.size + 1) / 2)]), x)
class ProcessGroupNCCLTest(TestCase):
    """Single-process, multi-GPU tests for the NCCL process group."""

    MAIN_PROCESS_RANK = 0

    def setUp(self):
        # Fixed: perform all skip checks before allocating resources.
        # Previously the NamedTemporaryFile was created before the CUDA
        # checks, and since unittest does not call tearDown when setUp
        # raises (SkipTest included), a skip left the file handle open.
        if not hasattr(c10d, "ProcessGroupNCCL"):
            raise unittest.SkipTest("C10D is not built with NCCL process group,"
                                    " skipping test")
        if not torch.cuda.is_available():
            raise unittest.SkipTest("torch.cuda not available, skipping test")
        self.num_gpus = torch.cuda.device_count()
        if self.num_gpus < 2:
            raise unittest.SkipTest("Requires at least 2 GPUs, skipping test")
        self.rank = self.MAIN_PROCESS_RANK
        self.size = 1
        self.file = tempfile.NamedTemporaryFile()

    def tearDown(self):
        self.file.close()

    def _input_tensors(self, offset=0):
        """One single-element tensor per GPU; tensor i holds i + offset on device i."""
        return [torch.Tensor([i + offset]).cuda(i) for i in range(self.num_gpus)]

    def test_broadcast_ops(self):
        store = c10d.FileStore(self.file.name)
        pg = c10d.ProcessGroupNCCL(store, self.rank, self.size)

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            work = pg.broadcast(xs, opts)
            work.wait()

        # For every root tensor index, all tensors must equal the root's.
        for rt in range(self.num_gpus):
            tensors = self._input_tensors()
            broadcast(tensors, self.rank, rt)
            for i in range(self.num_gpus):
                self.assertEqual(tensors[i], tensors[rt])

    def test_allreduce_ops(self):
        store = c10d.FileStore(self.file.name)
        pg = c10d.ProcessGroupNCCL(store, self.rank, self.size)

        def allreduce(tensors, op):
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            work = pg.allreduce(tensors, opts)
            work.wait()

        # Sum: GPUs contribute 1..num_gpus -> num_gpus*(num_gpus+1)/2.
        tensors = self._input_tensors(offset=1)
        allreduce(tensors, c10d.ReduceOp.SUM)
        for i in range(self.num_gpus):
            self.assertEqual(
                torch.Tensor([float(self.num_gpus * (self.num_gpus + 1) / 2)]),
                tensors[i])

        # Product: 1 * 2 * ... * num_gpus == num_gpus!.
        tensors = self._input_tensors(offset=1)
        allreduce(tensors, c10d.ReduceOp.PRODUCT)
        for i in range(self.num_gpus):
            self.assertEqual(
                torch.Tensor([float(math.factorial(self.num_gpus))]),
                tensors[i])

        # Min
        tensors = self._input_tensors(offset=1)
        allreduce(tensors, c10d.ReduceOp.MIN)
        for i in range(self.num_gpus):
            self.assertEqual(torch.Tensor([1.0]), tensors[i])

        # Max
        tensors = self._input_tensors(offset=1)
        allreduce(tensors, c10d.ReduceOp.MAX)
        for i in range(self.num_gpus):
            self.assertEqual(torch.Tensor([self.num_gpus]), tensors[i])
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()