Update as per review comments

Fix typo to resolve CI issues.
Fix issues as per review comments.
This commit is contained in:
amathewc 2025-01-21 12:01:43 +02:00 committed by PyTorch MergeBot
parent c8ae41e6ec
commit 6a29ea1efc

View file

@ -30,17 +30,16 @@ if TEST_WITH_DEV_DBG_ASAN:
sys.exit(0)
if TEST_HPU:
BACKEND = dist.Backend.HCCL
device_count = torch.hpu.device_count()
DEVICE = "hpu"
elif TEST_CUDA:
BACKEND = dist.Backend.NCCL
device_count = torch.cuda.device_count()
DEVICE = "cuda"
else:
BACKEND = dist.Backend.GLOO
DEVICE = "cpu"
device_module = torch.get_device_module(DEVICE)
device_count = device_module.device_count()
BACKEND = dist.get_default_backend_for_device(DEVICE)
def with_comms(func=None):
if func is None:
@ -50,10 +49,7 @@ def with_comms(func=None):
@wraps(func)
def wrapper(self, *args, **kwargs):
if (
BACKEND == (dist.Backend.NCCL or dist.Backend.HCCL)
and device_count < self.world_size
):
if DEVICE != "cpu" and device_count < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
kwargs["device"] = DEVICE
@ -84,6 +80,21 @@ class TestObjectCollectives(DistributedTestBase):
for i, v in enumerate(output):
self.assertEqual(i, v, f"rank: {self.rank}")
@skipIfHpu
@with_comms()
def test_send_recv_object_list(self, device):
val = 99 if self.rank == 0 else None
object_list = [val] * dist.get_world_size()
if self.rank == 0:
dist.send_object_list(object_list, 1)
if self.rank == 1:
dist.recv_object_list(object_list, 0)
if self.rank < 2:
self.assertEqual(99, object_list[0])
else:
self.assertEqual(None, object_list[0])
@with_comms()
def test_broadcast_object_list(self, device):
val = 99 if self.rank == 0 else None