mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Use fake PG for test_compute_comm_reordering.py unit tests (#131415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131415 Approved by: https://github.com/yifuwang
This commit is contained in:
parent
980bb54361
commit
fc3d2b26cd
2 changed files with 34 additions and 16 deletions
|
|
@ -26,7 +26,6 @@ from torch.testing._internal.common_distributed import (
|
|||
_dynamo_dist_per_rank_init,
|
||||
DynamoDistributedMultiProcTestCase,
|
||||
requires_nccl,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
|
|
@ -79,6 +78,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
|
||||
"""
|
||||
|
||||
def at_least_x_gpu(self, x):
|
||||
return torch.cuda.is_available() and torch.cuda.device_count() >= x
|
||||
|
||||
def get_world_trs(self):
|
||||
return {
|
||||
"tag": "",
|
||||
|
|
@ -93,7 +95,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
return 2
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
|
|
@ -112,7 +113,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
b = torch.matmul(a, a)
|
||||
return torch.matmul(ar, b)
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs)
|
||||
|
|
@ -131,7 +134,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
|
|
@ -152,7 +154,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
e = _functional_collectives.all_reduce(b, "sum", "0")
|
||||
return torch.matmul(d, e)
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs)
|
||||
|
|
@ -177,7 +181,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
|
|
@ -200,7 +203,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
g = torch.matmul(f, f)
|
||||
return torch.mm(e, g)
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
|
|
@ -229,7 +234,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
|
|
@ -252,7 +256,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
e = torch.matmul(d + ar + fr, g)
|
||||
return (e,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
|
|
@ -280,7 +286,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
|
|
@ -308,7 +313,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
e = torch.matmul(d + ar + fr, g)
|
||||
return (e,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
|
|
@ -336,7 +343,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
@patch.object(
|
||||
|
|
@ -356,7 +362,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
mm = torch.matmul(mul, ar)
|
||||
return (mm,)
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
|
||||
|
|
|
|||
|
|
@ -1219,14 +1219,24 @@ class SaveForwardInputsModel(nn.Module):
|
|||
return self.c2(self.c1(x))
|
||||
|
||||
@contextmanager
|
||||
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True):
|
||||
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
|
||||
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
|
||||
# Just manually implement the most important part of the dynamo behavior to reset/clear.
|
||||
torch.cuda.set_device(rank)
|
||||
if not fake_pg:
|
||||
torch.cuda.set_device(rank)
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '6789'
|
||||
if init_pg:
|
||||
c10d.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
if fake_pg:
|
||||
store = torch.testing._internal.distributed.fake_pg.FakeStore()
|
||||
c10d.init_process_group(
|
||||
backend="fake",
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
store=store,
|
||||
)
|
||||
else:
|
||||
c10d.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.utils.counters.clear()
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in a new issue