pytorch/test/distributed/test_multi_threaded_pg.py
Wanchao Liang e16979c9a0 [threaded_pg] full rewrite of MultiThreadedTestCase to enable device_type tests (#91650)
This PR did a full rewrite of MultiThreadedTestCase, to make it more
aligned with the MultiProcessTestCase, also changed how it do spawning
and testing, so that we could embed thread local states when running
tests.

This PR enables device_type tests to work with MultiThreadedTestCase
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91650
Approved by: https://github.com/XilunWu
2023-01-17 03:26:36 +00:00

176 lines
6.4 KiB
Python

# Owner(s): ["oncall: distributed"]
import sys
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import ReduceOp
from unittest import skip, SkipTest
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
spawn_threads_and_init_comms,
MultiThreadedTestCase,
)
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
IS_SANDCASTLE,
)
DEFAULT_WORLD_SIZE = 4
class TestCollectivesWithWrapper(TestCase):
@spawn_threads_and_init_comms(world_size=4)
def test_broadcast_object_list(self):
val = 99 if dist.get_rank() == 0 else None
object_list = [val] * dist.get_world_size()
dist.broadcast_object_list(object_list=object_list)
self.assertEqual(99, object_list[0])
def test_collective_error_on_rank_zero(self):
@spawn_threads_and_init_comms(world_size=4)
def _test_method(self):
input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather
output_tensors = [torch.empty_like(input_tensor) for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, input_tensor)
if dist.get_rank() == 0:
raise AssertionError("Mimic real test failure.") # fail on rank 0
dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather
with self.assertRaises(RuntimeError):
_test_method(self)
def test_collective_error_on_rank_non_zero(self):
@spawn_threads_and_init_comms(world_size=4)
def _test_method(self):
input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather
output_tensors = [torch.empty_like(input_tensor) for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, input_tensor)
if dist.get_rank() == 1:
raise AssertionError("Mimic real test failure.") # fail on rank 1
dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather
with self.assertRaises(RuntimeError):
_test_method(self)
def test_collective_error_on_rank_non_zero_all(self):
@spawn_threads_and_init_comms(world_size=4)
def _test_method(self):
input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather
output_tensors = [torch.empty_like(input_tensor) for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, input_tensor)
if dist.get_rank() > 0:
raise AssertionError("Mimic real test failure.") # fail on all non-zero rank
dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather
with self.assertRaises(RuntimeError):
_test_method(self)
def test_skip(self):
@spawn_threads_and_init_comms(world_size=4)
@skip("check if skip exception can be captured correctly.")
def _test_method(self):
pass
if not IS_SANDCASTLE:
with self.assertRaises(SkipTest):
_test_method(self)
class TestCollectivesWithBaseClass(MultiThreadedTestCase):
@property
def world_size(self):
return 4
def setUp(self):
super().setUp()
self._spawn_threads()
def test_allgather(self):
input_tensor = torch.ones(3, 3) * dist.get_rank()
output_tensors = [torch.empty_like(input_tensor) for _ in range(self.world_size)]
dist.all_gather(output_tensors, input_tensor)
for rank, out_tensor in enumerate(output_tensors):
self.assertEqual(out_tensor, torch.ones(3, 3) * rank)
def test_broadcast(self):
input_tensor = torch.ones(3, 3) * dist.get_rank()
for rank in range(self.world_size):
cloned_input = input_tensor.clone()
dist.broadcast(cloned_input, src=rank)
self.assertEqual(cloned_input, torch.ones(3, 3) * rank)
def test_scatter(self):
if dist.get_rank() == 0:
scatter_list = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
else:
scatter_list = None
output_tensor = torch.empty(3, 3)
dist.scatter(output_tensor, scatter_list)
self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank())
def test_reduce_scatter(self):
to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
output_tensor = torch.empty(3, 3)
dist.reduce_scatter(output_tensor, to_reduce_scatter)
expected_tensor = torch.ones(3, 3) * dist.get_rank() * self.world_size
self.assertEqual(output_tensor, expected_tensor)
def test_broadcast_object_list(self):
val = 99 if dist.get_rank() == 0 else None
object_list = [val] * dist.get_world_size()
print(f"{dist.get_rank()} -> {dist.get_world_size()}")
dist.broadcast_object_list(object_list=object_list)
self.assertEqual(99, object_list[0])
def test_all_reduce(self):
output = torch.ones(3, 3) * dist.get_rank()
dist.all_reduce(output)
res_num = ((0 + self.world_size - 1) * self.world_size) / 2
self.assertEqual(output, torch.ones(3, 3) * res_num)
# Test unimplemented error
with self.assertRaisesRegex(NotImplementedError, "only supports SUM on threaded pg for now"):
dist.all_reduce(output, op=ReduceOp.MAX)
def test_assert_equal_on_rank(self):
# RNG is shared across threads. So instead of asserting on all threads
# we only assert on rank 0
self_tensor = torch.rand(3, 3)
rank_0_tensor = self_tensor.clone()
dist.broadcast(rank_0_tensor, src=0)
self.assertEqualOnRank(rank_0_tensor, self_tensor, rank=0)
self.assertNotEqualOnRank(rank_0_tensor, self_tensor, rank=1)
def test_subpg(self):
subpg0 = dist.new_group([0, 1])
subpg1 = dist.new_group([2, 3])
current_rank = dist.get_rank()
output = torch.ones(3, 3) * current_rank
# call all_reduce on subpg0 and subpg1 concurrently
if current_rank in [0, 1]:
dist.all_reduce(output, group=subpg0)
else:
dist.all_reduce(output, group=subpg1)
if current_rank in [0, 1]:
self.assertEqual(output, torch.ones(3, 3) * 1)
else:
self.assertEqual(output, torch.ones(3, 3) * 5)
if __name__ == "__main__":
run_tests()