From 6a29ea1efcba1cae285e2c4e52f49045dbfc5223 Mon Sep 17 00:00:00 2001 From: amathewc Date: Tue, 21 Jan 2025 12:01:43 +0200 Subject: [PATCH] Update as per review comments Fix typo to resolve CI issues. Fix issues as per review comments. --- .../test_c10d_object_collectives.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/test/distributed/test_c10d_object_collectives.py b/test/distributed/test_c10d_object_collectives.py index bbcd5c1347c..594564c4560 100644 --- a/test/distributed/test_c10d_object_collectives.py +++ b/test/distributed/test_c10d_object_collectives.py @@ -30,17 +30,16 @@ if TEST_WITH_DEV_DBG_ASAN: sys.exit(0) if TEST_HPU: - BACKEND = dist.Backend.HCCL - device_count = torch.hpu.device_count() DEVICE = "hpu" elif TEST_CUDA: - BACKEND = dist.Backend.NCCL - device_count = torch.cuda.device_count() DEVICE = "cuda" else: - BACKEND = dist.Backend.GLOO DEVICE = "cpu" +device_module = torch.get_device_module(DEVICE) +device_count = device_module.device_count() +BACKEND = dist.get_default_backend_for_device(DEVICE) + def with_comms(func=None): if func is None: @@ -50,10 +49,7 @@ def with_comms(func=None): @wraps(func) def wrapper(self, *args, **kwargs): - if ( - BACKEND == (dist.Backend.NCCL or dist.Backend.HCCL) - and device_count < self.world_size - ): + if DEVICE != "cpu" and device_count < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) kwargs["device"] = DEVICE @@ -84,6 +80,21 @@ class TestObjectCollectives(DistributedTestBase): for i, v in enumerate(output): self.assertEqual(i, v, f"rank: {self.rank}") + @skipIfHpu + @with_comms() + def test_send_recv_object_list(self, device): + val = 99 if self.rank == 0 else None + object_list = [val] * dist.get_world_size() + if self.rank == 0: + dist.send_object_list(object_list, 1) + if self.rank == 1: + dist.recv_object_list(object_list, 0) + + if self.rank < 2: + self.assertEqual(99, object_list[0]) + else: + self.assertEqual(None, object_list[0]) + @with_comms() def test_broadcast_object_list(self, device): val = 99 if self.rank == 0 else None