mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[pgnccl] add a restart test for PGs in blocking mode (#139496)
Summary: Restarting (aborting and re-initialize a PG) is a basic need if we want to achieve in-process restart of PGs without tearing down the whole process. Add this tests to verify that this is supported by current NCCL. Note that this restart test passes steadily only for blocking mode for now. In nonblockin mode. There is problem in either nccl init or abort that needs further investigation Test Plan: new UT Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/139496 Approved by: https://github.com/c-p-i-o, https://github.com/kwen2501
This commit is contained in:
parent
0b13bdd877
commit
4c64a7f33f
1 changed files with 37 additions and 0 deletions
|
|
@ -347,6 +347,43 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
|||
with self.assertRaises(ValueError):
|
||||
dist.all_reduce(t)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_restart_pg(self):
|
||||
# Note: restart test passes steadily only for blocking mode for now.
|
||||
# TODO: expand this test to non-blocking mode
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
|
||||
|
||||
# initialize pg for the first time
|
||||
c10d.init_process_group(
|
||||
"nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
t0 = torch.rand(10, 10, device=device)
|
||||
# First allreduce to lazy initialize default pg
|
||||
dist.all_reduce(t0)
|
||||
torch.cuda.synchronize()
|
||||
# Destroy pg
|
||||
dist.destroy_process_group()
|
||||
|
||||
# re-initialize pg
|
||||
c10d.init_process_group(
|
||||
"nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
t1 = torch.rand(5, 5, device=device)
|
||||
dist.all_reduce(t1)
|
||||
torch.cuda.synchronize()
|
||||
dist.destroy_process_group()
|
||||
# validate default pg is no longer valid
|
||||
with self.assertRaises(ValueError):
|
||||
dist.all_reduce(t1)
|
||||
|
||||
CUDA_12_AND_ABOVE = torch.cuda.is_available() and (
|
||||
torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue