Use fake PG for test_compute_comm_reordering.py unit tests (#131415)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131415
Approved by: https://github.com/yifuwang
Author: Will Feng
Date: 2024-07-22 20:26:55 -07:00
Committed by: PyTorch MergeBot
Commit: fc3d2b26cd
Parent: 980bb54361
2 changed files with 34 additions and 16 deletions

test_compute_comm_reordering.py

@@ -26,7 +26,6 @@ from torch.testing._internal.common_distributed import (
     _dynamo_dist_per_rank_init,
     DynamoDistributedMultiProcTestCase,
     requires_nccl,
-    skip_if_lt_x_gpu,
 )
 from torch.utils._triton import has_triton
@@ -79,6 +78,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
     Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
     """

+    def at_least_x_gpu(self, x):
+        return torch.cuda.is_available() and torch.cuda.device_count() >= x
+
     def get_world_trs(self):
         return {
             "tag": "",
@@ -93,7 +95,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         return 2

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -112,7 +113,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             b = torch.matmul(a, a)
             return torch.matmul(ar, b)

-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs)
@@ -131,7 +134,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             self.assertTrue(same(out, correct))

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -152,7 +154,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             e = _functional_collectives.all_reduce(b, "sum", "0")
             return torch.matmul(d, e)

-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs)
@@ -177,7 +181,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             self.assertTrue(same(out, correct))

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -200,7 +203,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             g = torch.matmul(f, f)
             return torch.mm(e, g)

-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@@ -229,7 +234,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             self.assertTrue(same(out, correct))

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -252,7 +256,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             e = torch.matmul(d + ar + fr, g)
             return (e,)

-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@@ -280,7 +286,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             self.assertTrue(same(out, correct))

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -308,7 +313,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             e = torch.matmul(d + ar + fr, g)
             return (e,)

-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@@ -336,7 +343,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             self.assertTrue(same(out, correct))

     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
     @patch.object(
@@ -356,7 +362,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             mm = torch.matmul(mul, ar)
             return (mm,)

-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
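
Every hunk in this file makes the same two changes: the hard @skip_if_lt_x_gpu(2) gate is dropped, and _dynamo_dist_per_rank_init is asked for a fake process group whenever fewer than two GPUs are visible, so the generated-code assertions can still run on a single-GPU or CPU-only runner. A rough sketch of that per-test pattern follows; the class and test names are hypothetical, while the base class, helper, at_least_x_gpu check, and fake_pg flag come from this diff.

import torch
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
)


class ExampleReorderingTest(DynamoDistributedMultiProcTestCase):  # hypothetical name
    @property
    def world_size(self):
        return 2

    def at_least_x_gpu(self, x):
        return torch.cuda.is_available() and torch.cuda.device_count() >= x

    def test_compiles_without_two_gpus(self):  # hypothetical test
        # No @skip_if_lt_x_gpu(2) gate: fall back to the fake backend so the
        # process-group setup succeeds even without a second GPU.
        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
        ):
            # Compile the function under test and inspect the generated code here.
            pass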

torch/testing/_internal/common_distributed.py

@@ -1219,14 +1219,24 @@ class SaveForwardInputsModel(nn.Module):
         return self.c2(self.c1(x))

 @contextmanager
-def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True):
+def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
     # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
     # Just manually implement the most important part of the dynamo behavior to reset/clear.
-    torch.cuda.set_device(rank)
+    if not fake_pg:
+        torch.cuda.set_device(rank)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '6789'
     if init_pg:
-        c10d.init_process_group("nccl", rank=rank, world_size=world_size)
+        if fake_pg:
+            store = torch.testing._internal.distributed.fake_pg.FakeStore()
+            c10d.init_process_group(
+                backend="fake",
+                world_size=world_size,
+                rank=rank,
+                store=store,
+            )
+        else:
+            c10d.init_process_group("nccl", rank=rank, world_size=world_size)
     torch._dynamo.reset()
     torch._dynamo.utils.counters.clear()
     try:
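
For reference, here is a minimal standalone sketch of what the fake_pg branch above sets up: a size-2 process group backed by the in-process "fake" backend and a FakeStore, so rank and world-size bookkeeping work without NCCL, extra processes, or extra GPUs. The rank of 0 and world size of 2 below are illustrative choices, not values from the diff.

import torch.distributed as dist
# Importing fake_pg provides FakeStore and registers the "fake" backend.
from torch.testing._internal.distributed.fake_pg import FakeStore

store = FakeStore()
# No real communication happens, so a world_size of 2 can be claimed from a
# single process on a machine with one GPU or none at all.
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

assert dist.get_world_size() == 2
assert dist.get_rank() == 0

dist.destroy_process_group()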