mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
- [x] Direct dependency on UCX is completely removed, UCC active set API always enabled
- [x] Remove `TORCH_UCC_PROFILING_ENABLE`, always enable profiling
- [x] Fixes profiling of `recv` and `all_gather`
- [x] Use the NCCL TL of UCC on CUDA, as the UCP TL is not well supported on CUDA

Most tests are passing, but there are a few skipped tests:
- `scatter` and `gather` are not supported by the UCP TL of UCC on CPU tensors
- A few flaky tests in PyTorch's CI environment
- Profiler-related failures, some of which will be fixed by @Fuzzkatt in https://github.com/pytorch/pytorch/pull/84368

After this PR is merged, I will continue to work on these skipped failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83285
Approved by: https://github.com/vtlam, https://github.com/malfet, https://github.com/kwen2501
40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
# Disable TF32 for CUDA matmuls process-wide while this module is loaded.
# NOTE(review): presumably so numeric comparisons in the distributed tests run
# at full float32 precision -- confirm against the test suite's tolerances.
torch.backends.cuda.matmul.allow_tf32 = False

# Bail out with a success status (0) when torch.distributed is not available.
# This check has to stay ahead of the torch.testing._internal.distributed
# import that follows it in this file.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
|
|
|
|
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN, NO_MULTIPROCESSING_SPAWN
|
|
from torch.testing._internal.distributed.distributed_test import (
|
|
DistributedTest, TestDistBackend
|
|
)
|
|
|
|
# Exit with status 0 (nothing is run) under dev/debug ASAN builds; the message
# below records the reason: torch + multiprocessing spawn have known issues.
if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

# Exit with status 0 when the multiprocessing "spawn" start method is flagged
# as unavailable by common_utils.
if NO_MULTIPROCESSING_SPAWN:
    print("Spawn not available, skipping tests.", file=sys.stderr)
    sys.exit(0)

# Backend under test, read from the BACKEND environment variable.
# Raises KeyError at import time if the variable is not set.
BACKEND = os.environ["BACKEND"]
|
|
|
|
# The spawn-based test class is only defined for these three backends; for any
# other BACKEND value this module contributes no test cases.
if BACKEND in ("gloo", "nccl", "ucc"):

    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
        """Distributed test case that starts its workers via ``_spawn_processes``.

        Combines ``TestDistBackend`` with ``DistributedTest._DistTestBase``; the
        only behavior added here is the ``setUp`` override below.
        """

        def setUp(self):
            """Run the base setup, spawn the worker processes, and disable
            cuDNN TF32 until the test's cleanups run.
            """
            super().setUp()
            self._spawn_processes()
            # Keep a reference to the context manager and register its exit as
            # a cleanup. Calling ``__enter__()`` on a temporary left the
            # override without a matching ``__exit__``: the flags were either
            # never restored deliberately or restored as soon as the discarded
            # (generator-based) manager was finalized by the garbage collector.
            cudnn_flags = torch.backends.cudnn.flags(allow_tf32=False)
            cudnn_flags.__enter__()
            self.addCleanup(cudnn_flags.__exit__, None, None, None)
|
|
|
|
|
|
# Script entry point: hand control to the shared PyTorch test runner when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    run_tests()
|