diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py
index 657565974fe..4a31e3eec42 100644
--- a/test/distributed/test_compute_comm_reordering.py
+++ b/test/distributed/test_compute_comm_reordering.py
@@ -26,7 +26,6 @@ from torch.testing._internal.common_distributed import (
     _dynamo_dist_per_rank_init,
     DynamoDistributedMultiProcTestCase,
     requires_nccl,
-    skip_if_lt_x_gpu,
 )
 from torch.utils._triton import has_triton
 
@@ -79,6 +78,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
     Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
     """
 
+    def at_least_x_gpu(self, x):
+        return torch.cuda.is_available() and torch.cuda.device_count() >= x
+
     def get_world_trs(self):
         return {
             "tag": "",
@@ -93,7 +95,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         return 2
 
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -112,7 +113,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             b = torch.matmul(a, a)
             return torch.matmul(ar, b)
 
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs)
@@ -131,7 +134,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         self.assertTrue(same(out, correct))
 
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -152,7 +154,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             e = _functional_collectives.all_reduce(b, "sum", "0")
             return torch.matmul(d, e)
 
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs)
@@ -177,7 +181,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         self.assertTrue(same(out, correct))
 
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -200,7 +203,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             g = torch.matmul(f, f)
             return torch.mm(e, g)
 
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@@ -229,7 +234,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         self.assertTrue(same(out, correct))
 
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -252,7 +256,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             e = torch.matmul(d + ar + fr, g)
             return (e,)
 
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@@ -280,7 +286,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         self.assertTrue(same(out, correct))
 
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
@@ -308,7 +313,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             e = torch.matmul(d + ar + fr, g)
             return (e,)
 
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
@@ -336,7 +343,6 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
         self.assertTrue(same(out, correct))
 
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_lt_x_gpu(2)
     # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
     @patch.object(torch._inductor.config, "compile_threads", 1)
     @patch.object(
@@ -356,7 +362,9 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             mm = torch.matmul(mul, ar)
             return (mm,)
 
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+        with _dynamo_dist_per_rank_init(
+            self.rank, self.world_size, fake_pg=not self.at_least_x_gpu(2)
+        ):
             inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index a0a3429797c..f8e5bc8a484 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -1219,14 +1219,24 @@ class SaveForwardInputsModel(nn.Module):
         return self.c2(self.c1(x))
 
 @contextmanager
-def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True):
+def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
     # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
     # Just manually implement the most important part of the dynamo behavior to reset/clear.
-    torch.cuda.set_device(rank)
+    if not fake_pg:
+        torch.cuda.set_device(rank)
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '6789'
     if init_pg:
-        c10d.init_process_group("nccl", rank=rank, world_size=world_size)
+        if fake_pg:
+            store = torch.testing._internal.distributed.fake_pg.FakeStore()
+            c10d.init_process_group(
+                backend="fake",
+                world_size=world_size,
+                rank=rank,
+                store=store,
+            )
+        else:
+            c10d.init_process_group("nccl", rank=rank, world_size=world_size)
     torch._dynamo.reset()
     torch._dynamo.utils.counters.clear()
     try:
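The mechanism this patch relies on can be seen in isolation. The `common_distributed.py` hunk makes `_dynamo_dist_per_rank_init` able to bring up a "fake" process group backed by `FakeStore`, which satisfies `init_process_group` without any rendezvous or real communication, so `world_size` may exceed the number of physical GPUs. A minimal standalone sketch of that path (not part of this patch; the rank/world-size values are arbitrary):

```python
import torch
import torch.distributed as dist
# Importing fake_pg also registers the "fake" backend with c10d,
# which is what makes backend="fake" below resolvable.
from torch.testing._internal.distributed.fake_pg import FakeStore

# Same call shape as the new fake_pg branch in _dynamo_dist_per_rank_init:
# FakeStore stands in for a TCPStore, so no sockets are opened and no
# peer processes need to exist.
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

print(dist.get_world_size())  # 2, even on a single-GPU or CPU-only machine

dist.destroy_process_group()
```

This is why the test-file hunks can drop `skip_if_lt_x_gpu(2)`: on under-provisioned machines each test now falls back to a fake process group via `fake_pg=not self.at_least_x_gpu(2)` instead of skipping, while still exercising Inductor's compute/communication reordering in the generated Triton code.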