## Summary
This PR added 3 intra-node GPU allreduce algorithms to PyTorch:
- One-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate data from other ranks.
- Two-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate `1 / world_size` of the data from other ranks. Then all ranks read the accumulated data from other ranks (effectively a one-shot reduce-scatter followed by a one-shot all-gather).
- Hybrid cube mesh allreduce (original): a one-shot allreduce variant that avoids transmission over PCIe on HCM topology.
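The data movement of the one-shot and two-shot variants can be sketched with a small CPU-only simulation. This is an illustration under stated assumptions, not the CUDA kernels from this PR: each rank's buffer is modeled as a plain Python list, the function names are hypothetical, and the buffer length is assumed to be divisible by the world size.

```python
from typing import List


def one_shot_allreduce(buffers: List[List[int]]) -> List[List[int]]:
    """Every rank reads the full buffer of every rank and sums it locally."""
    n = len(buffers[0])
    return [
        [sum(peer[i] for peer in buffers) for i in range(n)]
        for _rank in buffers
    ]


def two_shot_allreduce(buffers: List[List[int]]) -> List[List[int]]:
    """Reduce-scatter followed by all-gather (hypothetical helper)."""
    world_size = len(buffers)
    n = len(buffers[0])
    assert n % world_size == 0, "sketch assumes an evenly divisible buffer"
    chunk = n // world_size

    # Shot 1: rank r accumulates only its own 1 / world_size slice.
    partial = []
    for r in range(world_size):
        lo, hi = r * chunk, (r + 1) * chunk
        partial.append([sum(peer[i] for peer in buffers) for i in range(lo, hi)])

    # Shot 2: every rank reads each reduced slice from the rank that owns it.
    gathered = [value for r in range(world_size) for value in partial[r]]
    return [list(gathered) for _rank in range(world_size)]
```

Both functions produce the same result; the difference is how much data each rank reads in each phase.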
## Micro Benchmarks



## Details
The intra-node algos are organized behind `c10d::IntraNodeComm`, which is responsible for:
- Managing handshaking and cuda IPC handle exchange among ranks.
- Querying NVLink connection and detecting topology.
- Performing algo selection based on available info.
- Launching the selected allreduce kernel.
`c10d::IntraNodeComm` is integrated into `c10d::ProcessGroupNCCL` as follows:
- When the `ENABLE_INTRA_NODE_COMM` environment variable is set, `c10d::ProcessGroupNCCL` initializes a `c10d::IntraNodeComm` for its ranks.
- If the setup is not suitable for intra-node comm (e.g. not all ranks are from the same node), the rendezvous logic guarantees all participants fall back consistently.
- `c10d::ProcessGroupNCCL::allreduce` consults `c10d::IntraNodeComm` whether to use intra-node allreduce and carries out the communication accordingly.
We currently detect two types of topologies from the NVLink connection mesh:
- Fully connected: all GPU pairs have a direct NVLink connection (e.g. NVSwitch or a fully connected subset of a hybrid cube mesh)
  - `msg <= 256KB`: one-shot allreduce.
  - `256KB < msg <= 10MB`: two-shot allreduce.
  - `msg > 10MB`: instructs the caller to fall back to NCCL.
- Hybrid cube mesh
  - `msg <= 256KB`: one-shot allreduce.
  - `msg > 256KB`: instructs the caller to fall back to NCCL.
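The thresholds above can be read as a small decision function. The sketch below is only an illustration: the enum and function names are hypothetical (they are not the `c10d::IntraNodeComm` API), and it assumes that `KB` and `MB` mean 1024 and 1024 * 1024 bytes, which the description does not specify.

```python
from enum import Enum

KB = 1024        # assumption: binary units
MB = 1024 * KB


class Topology(Enum):
    FULLY_CONNECTED = "fully_connected"
    HYBRID_CUBE_MESH = "hybrid_cube_mesh"
    UNKNOWN = "unknown"


class Algo(Enum):
    ONE_SHOT = "one_shot"
    TWO_SHOT = "two_shot"
    HCM_ONE_SHOT = "hcm_one_shot"   # the one-shot variant for HCM topology
    FALLBACK_NCCL = "fallback_nccl"


def select_allreduce_algo(topology: Topology, msg_bytes: int) -> Algo:
    """Mirror the size thresholds listed above (hypothetical helper)."""
    if topology is Topology.FULLY_CONNECTED:
        if msg_bytes <= 256 * KB:
            return Algo.ONE_SHOT
        if msg_bytes <= 10 * MB:
            return Algo.TWO_SHOT
        return Algo.FALLBACK_NCCL
    if topology is Topology.HYBRID_CUBE_MESH:
        if msg_bytes <= 256 * KB:
            return Algo.HCM_ONE_SHOT
        return Algo.FALLBACK_NCCL
    # Any other topology is left to NCCL.
    return Algo.FALLBACK_NCCL
```

Returning a dedicated fallback value keeps the decision in one place and lets the caller route the message to NCCL.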
## Next Steps
- Fine tune algo selection based on GPU model, topology, link speed.
- Potentially optimize the two-shot allreduce impl. According to FasterTransformer, two-shot allreduce is preferred for messages up to 50MB. There might be room for improvement, but PyTorch does impose more constraints:
  - FasterTransformer uses a single process to drive multiple devices. It can use `cudaDeviceEnablePeerAccess` to enable device-level peer access.
  - PyTorch uses multiple processes to drive multiple devices. With cuda IPC, a device can only share a specific region with other devices. This means extra copies may be unavoidable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114001
Approved by: https://github.com/yf225
Since https://github.com/pytorch/pytorch/pull/99699 introduced a dependency on nvml for OOM reporting (in `c10/cuda/driver_api.h`, `c10/cuda/driver_api.cpp`, and `reportProcessMemoryInfo` in `c10/cuda/CUDACachingAllocator.cpp`), NVIDIA's internal CI has seen failures related to cuda expandable segments and OOM reporting. The failures occur specifically on Jetson devices, where nvml is not supported. Example failures using the latest upstream on an Orin AGX node:
`python test/test_cuda.py -k test_notifies_oom` generates
```
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/pytorch/pytorch/test/test_cuda.py", line 1643, in _worker
    results[t] = torch.nn.functional.conv2d(results[t], weight, padding=0)
RuntimeError: CUDA driver error: out of memory
```
`python test/test_cuda_expandable_segments.py` generates
```
Traceback (most recent call last):
  File "/opt/pytorch/pytorch/test/test_cuda_expandable_segments.py", line 12, in <module>
    exec(compile(open(filepath).read(), filepath, mode='exec'))
  File "/opt/pytorch/pytorch/test/test_cuda.py", line 66, in <module>
    class TestCuda(TestCase):
  File "/opt/pytorch/pytorch/test/test_cuda.py", line 1609, in TestCuda
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 4628, in wrapped
    self._value = self._cb()
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_cuda.py", line 20, in <lambda>
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
RuntimeError: handle_0 INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/c10/cuda/driver_api.cpp":15, please report a bug to PyTorch.
```
This PR intends to fix this issue by adding dlopen checks to make sure nvml actually exists, and by safely falling back to the older libcuda-based paths for cuda expandable segments and OOM reporting if nvml is not found.
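The check-then-fall-back pattern can be illustrated in Python with `ctypes`, which wraps dlopen on Linux. This is a minimal sketch, not the C++ change in this PR: the helper names and the returned backend strings are hypothetical, and `libnvidia-ml.so.1` is assumed to be the library name to probe.

```python
import ctypes
from typing import Optional, Sequence


def try_load_library(candidates: Sequence[str]) -> Optional[ctypes.CDLL]:
    """Return the first library that loads, or None if none can be opened."""
    for name in candidates:
        try:
            return ctypes.CDLL(name)
        except OSError:
            # Library is missing or cannot be loaded on this platform.
            continue
    return None


def pick_memory_reporting_backend(
    nvml_names: Sequence[str] = ("libnvidia-ml.so.1",),
) -> str:
    """Use nvml when it can be loaded, otherwise fall back (hypothetical)."""
    nvml = try_load_library(nvml_names)
    return "nvml" if nvml is not None else "libcuda"
```

On a platform without nvml, such as the Jetson devices described above, the load attempt fails and the caller takes the fallback path instead of hitting an assertion.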
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112121
Approved by: https://github.com/eqy, https://github.com/ngimel, https://github.com/albanD
This PR adds calls to nvml during an OOM to find out the total memory
in use by the process and any other CUDA processes on the device.
This makes it easier to identify cases where non-PyTorch libraries have
allocated memory or another process (such as a data loader) has also
allocated memory on the device.
This also rewords the other parts of the error message to make the meaning
of the memory statistics more clear with this new information:
"""
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 138.00 MiB.
GPU 0 has a total capacty of 15.90 GiB of which 8.44 MiB is free.
Process 1246069 has 577.00 MiB memory in use. Including non-PyTorch memory,
this process has 15.32 GiB memory in use. Of the allocated memory
14.12 GiB is allocated by PyTorch, and 410.41 MiB is reserved
by PyTorch but unallocated. If reserved but unallocated memory is large
try setting max_split_size_mb to avoid fragmentation. See documentation
for Memory Management and PYTORCH_CUDA_ALLOC_CONF
"""
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99699
Approved by: https://github.com/ngimel
Common advice we give for handling memory fragmentation issues is to
allocate a big block upfront to reserve memory which will get split up later.
For programs with changing tensor sizes this can be especially helpful to
avoid OOMs that happen the first time we see a new largest input and would
otherwise have to allocate new segments.
However, the issue with allocating a block upfront is that it is nearly impossible
to correctly estimate the size of that block. If too small, space in the block
will run out and the allocator will allocate separate blocks anyway. Too large,
and other non-PyTorch libraries might stop working because they cannot allocate
any memory.
This patch provides the same benefits as using a pre-allocated block but
without having to choose its size upfront. Using the cuMemMap-style APIs,
it adds the ability to expand the last block in a segment when more memory is
needed.
Compared to universally using cudaMallocAsync to avoid fragmentation,
this patch can fix this common fragmentation issue while preserving most
of the existing allocator behavior. This behavior can be enabled and disabled dynamically.
This should allow users to, for instance, allocate long-lived parameters and state in individual buffers,
and put temporary state into the large expandable blocks, further reducing
fragmentation.
See inline comments for information about the implementation and its limitations.
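The idea of growing a segment on demand can be modeled with a toy allocator. The sketch below is a simplified illustration, not the allocator code from this patch: the class name and fields are hypothetical, allocation is a simple bump pointer, and incrementing `mapped` stands in for the cuMemMap-style calls that would back more of a reserved address range with physical memory.

```python
class ToyExpandableSegment:
    """Toy model: one reserved address range, mapped page by page on demand."""

    def __init__(self, virtual_size: int, page_size: int) -> None:
        self.virtual_size = virtual_size  # reserved address space, in bytes
        self.page_size = page_size        # mapping granularity, in bytes
        self.mapped = 0                   # bytes currently backed by memory
        self.used = 0                     # bytes handed out so far

    def allocate(self, nbytes: int) -> int:
        """Return the offset of a new block, expanding the tail if needed."""
        needed = self.used + nbytes
        if needed > self.virtual_size:
            raise MemoryError("reserved address range exhausted")
        if needed > self.mapped:
            # Round the shortfall up to whole pages and map them at the end.
            shortfall = needed - self.mapped
            pages = -(-shortfall // self.page_size)
            self.mapped += pages * self.page_size
        offset = self.used
        self.used = needed
        return offset
```

No size has to be chosen upfront beyond the reserved address range: physical memory is only committed as allocations reach past the mapped region.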
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96995
Approved by: https://github.com/eellison