mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update as per review comments
Fix typo to resolve CI issues. Fix issues as per review comments.
This commit is contained in:
parent
c8ae41e6ec
commit
6a29ea1efc
1 changed files with 20 additions and 9 deletions
|
|
@ -30,17 +30,16 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
sys.exit(0)
|
||||
|
||||
if TEST_HPU:
|
||||
BACKEND = dist.Backend.HCCL
|
||||
device_count = torch.hpu.device_count()
|
||||
DEVICE = "hpu"
|
||||
elif TEST_CUDA:
|
||||
BACKEND = dist.Backend.NCCL
|
||||
device_count = torch.cuda.device_count()
|
||||
DEVICE = "cuda"
|
||||
else:
|
||||
BACKEND = dist.Backend.GLOO
|
||||
DEVICE = "cpu"
|
||||
|
||||
device_module = torch.get_device_module(DEVICE)
|
||||
device_count = device_module.device_count()
|
||||
BACKEND = dist.get_default_backend_for_device(DEVICE)
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
if func is None:
|
||||
|
|
@ -50,10 +49,7 @@ def with_comms(func=None):
|
|||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if (
|
||||
BACKEND == (dist.Backend.NCCL or dist.Backend.HCCL)
|
||||
and device_count < self.world_size
|
||||
):
|
||||
if DEVICE != "cpu" and device_count < self.world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||
|
||||
kwargs["device"] = DEVICE
|
||||
|
|
@ -84,6 +80,21 @@ class TestObjectCollectives(DistributedTestBase):
|
|||
for i, v in enumerate(output):
|
||||
self.assertEqual(i, v, f"rank: {self.rank}")
|
||||
|
||||
@skipIfHpu
|
||||
@with_comms()
|
||||
def test_send_recv_object_list(self, device):
|
||||
val = 99 if self.rank == 0 else None
|
||||
object_list = [val] * dist.get_world_size()
|
||||
if self.rank == 0:
|
||||
dist.send_object_list(object_list, 1)
|
||||
if self.rank == 1:
|
||||
dist.recv_object_list(object_list, 0)
|
||||
|
||||
if self.rank < 2:
|
||||
self.assertEqual(99, object_list[0])
|
||||
else:
|
||||
self.assertEqual(None, object_list[0])
|
||||
|
||||
@with_comms()
|
||||
def test_broadcast_object_list(self, device):
|
||||
val = 99 if self.rank == 0 else None
|
||||
|
|
|
|||
Loading…
Reference in a new issue