pytorch/torch
Yifu Wang adfbd2b219 Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (#114001)
## Summary
This PR adds 3 intra-node GPU allreduce algorithms to PyTorch:
- One-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate data from other ranks.
- Two-shot allreduce (inspired by FasterTransformer): all ranks simultaneously read and accumulate `1 / world_size` of the data from other ranks. Then all ranks read the accumulated data from other ranks (effectively a one-shot reduce-scatter + a one-shot all-gather; see the sketch below).
- Hybrid cube mesh allreduce (original): a one-shot allreduce variant that avoids transmission over PCIe on HCM topology.
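
The snippet below is a single-process Python sketch (not the actual CUDA kernels) illustrating why the two-shot decomposition is equivalent to one-shot allreduce: a reduce-scatter over `1 / world_size` slices followed by an all-gather. The sizes, `world_size`, and slicing are purely illustrative.

```python
import torch

world_size = 4
n = 8  # message length, divisible by world_size for simplicity
inputs = [torch.randn(n) for _ in range(world_size)]

# One-shot: every rank reads and accumulates the full buffers of all peers.
one_shot = [sum(inputs) for _ in range(world_size)]

# Two-shot, step 1 (reduce-scatter): rank r accumulates only its 1/world_size slice.
chunk = n // world_size
partial = [sum(x[r * chunk:(r + 1) * chunk] for x in inputs) for r in range(world_size)]

# Two-shot, step 2 (all-gather): every rank reads the reduced slices from all peers.
two_shot = [torch.cat(partial) for _ in range(world_size)]

assert all(torch.allclose(a, b) for a, b in zip(one_shot, two_shot))
```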

## Micro Benchmarks
![image](https://github.com/pytorch/pytorch/assets/4156752/7bd25ffc-cd5b-4acb-bd65-b01bc136726e)

![image](https://github.com/pytorch/pytorch/assets/4156752/3ced31b4-6c31-4f34-a2d8-c072df29ae0e)

![image](https://github.com/pytorch/pytorch/assets/4156752/5b942c05-4fcc-4ec9-ae29-12c64080bb1c)

## Details
The intra-node algos are organized behind `c10d::IntraNodeComm` (a rough flow sketch follows this list), which is responsible for:
- Managing handshaking and CUDA IPC handle exchange among ranks.
- Querying NVLink connection and detecting topology.
- Performing algo selection based on available info.
- Launching the selected allreduce kernel.
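
As a rough mental model, these responsibilities could be sketched as the Python skeleton below. All names and signatures here are hypothetical and do not reflect the actual C++ interface in `torch/csrc`.

```python
class IntraNodeCommSketch:
    """Illustrative skeleton only; the real c10d::IntraNodeComm is C++/CUDA."""

    def __init__(self, rank: int, world_size: int, store):
        self.rank, self.world_size, self.store = rank, world_size, store

    def rendezvous(self) -> bool:
        # Exchange CUDA IPC handles through the store; return False if the setup
        # is unsuitable (e.g. ranks span multiple nodes) so that all participants
        # fall back consistently.
        ...

    def detect_topology(self) -> str:
        # Query the NVLink connection mesh and classify it, e.g.
        # "fully_connected", "hybrid_cube_mesh", or "unsupported".
        ...

    def select_algo(self, msg_bytes: int) -> str:
        # Pick one-shot / two-shot / NCCL fallback from topology and message size.
        ...

    def all_reduce(self, tensor):
        # Launch the selected intra-node allreduce kernel, or tell the caller
        # (ProcessGroupNCCL) to fall back to NCCL.
        ...
```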

`c10d::IntraNodeComm` is integrated into `c10d::ProcessGroupNCCL` as follows:
- When the `ENABLE_INTRA_NODE_COMM` environment variable is set, `c10d::ProcessGroupNCCL` initializes a `c10d::IntraNodeComm` for its ranks.
  - If the setup is not suitable for intra-node comm (e.g. not all ranks are from the same node), the rendezvous logic guarantees all participants fall back consistently.
- `c10d::ProcessGroupNCCL::allreduce` consults `c10d::IntraNodeComm` on whether to use intra-node allreduce and carries out the communication accordingly (see the usage sketch below).
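
A minimal usage sketch, assuming a single node with one GPU per rank. Only the `ENABLE_INTRA_NODE_COMM` environment variable and the standard `torch.distributed` calls come from this description; the addresses, ports, and tensor sizes are illustrative.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int):
    # Must be set before the NCCL process group is created.
    os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # ~4KB fp32 message: small enough to be eligible for one-shot allreduce on a
    # fully connected topology; larger messages fall back as described below.
    t = torch.full((1024,), float(rank), device="cuda")
    dist.all_reduce(t)  # ProcessGroupNCCL consults IntraNodeComm internally

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```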

We currently detect two types of topologies from the NVLink connection mesh (the selection logic is sketched after this list):
- Fully connected: all GPU pairs have a direct NVLink connection (e.g. NVSwitch or a fully connected subset of a hybrid cube mesh)
  - `msg <= 256KB`: one-shot allreduce.
  - `256KB < msg <= 10MB`: two-shot allreduce.
  - `msg > 10MB`: instructs the caller to fall back to NCCL.
- Hybrid cube mesh
  - `msg <= 256KB`: one-shot allreduce.
  - `msg > 256KB`: instructs the caller to fall back to NCCL.
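
For clarity, the dispatch rules above can be summarized as the small sketch below; `select_allreduce_algo` and `Topology` are hypothetical names, not the actual `c10d::IntraNodeComm` API.

```python
from enum import Enum

class Topology(Enum):
    FULLY_CONNECTED = 1
    HYBRID_CUBE_MESH = 2

def select_allreduce_algo(topology: Topology, msg_bytes: int) -> str:
    KB, MB = 1024, 1024 * 1024
    if topology is Topology.FULLY_CONNECTED:
        if msg_bytes <= 256 * KB:
            return "one_shot"
        if msg_bytes <= 10 * MB:
            return "two_shot"
        return "nccl_fallback"
    if topology is Topology.HYBRID_CUBE_MESH:
        return "one_shot" if msg_bytes <= 256 * KB else "nccl_fallback"
    return "nccl_fallback"

assert select_allreduce_algo(Topology.FULLY_CONNECTED, 128 * 1024) == "one_shot"
assert select_allreduce_algo(Topology.HYBRID_CUBE_MESH, 1024 * 1024) == "nccl_fallback"
```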

## Next Steps
- Fine-tune algo selection based on GPU model, topology, and link speed.
- Potentially optimize the two-shot allreduce impl. According to FasterTransformer, two-shot allreduce is preferred up to 50MB. There might be room for improvement, but PyTorch does impose more constraints:
  - FasterTransformer uses a single process to drive multiple devices. It can use `cudaDeviceEnablePeerAccess` to enable device-level peer access.
  - PyTorch uses multiple processes to drive multiple devices. With CUDA IPC, a device can only share a specific region with other devices. This means extra copies may be unavoidable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114001
Approved by: https://github.com/yf225
2023-12-14 08:13:08 +00:00
_awaits
_C [c10d] Create a python c10d API _set_pg_timeout to set timeout (#115453) 2023-12-12 20:52:43 +00:00
_C_flatbuffer
_custom_op Allow functionalization to work with optional mutable (#114803) 2023-11-30 23:48:03 +00:00
_decomp Use get_mkldnn_enabled for decompositions (#115448) 2023-12-12 22:42:51 +00:00
_dispatch
_dynamo [Reland][HigherOrderOp] remove unused get_item in MapHigherOrder (#115758) 2023-12-14 00:41:46 +00:00
_export [export] Preserve FQN in export_to_torch_ir (#115462) 2023-12-13 04:58:47 +00:00
_functorch Consider storage_changed for assigning alias_of_input in aot_autograd when computing differentiable outputs that alias each other (#115315) 2023-12-12 23:21:58 +00:00
_higher_order_ops Allow preserve_rng_state=True when torch.compile + selective checkpointing + CUDA (#113718) 2023-12-09 01:47:25 +00:00
_inductor [inductor] Added non-integer expr support for floordiv in triton codegen (#115751) 2023-12-13 23:17:42 +00:00
_lazy
_library Refactor can_auto_functionalize (#115134) 2023-12-05 22:43:06 +00:00
_logging Sort the output of TORCH_LOGS=help (#114657) 2023-11-30 20:13:51 +00:00
_numpy [BE]: Enable a PLC0131, PLC0132, PLC0205. Fix PLC0132 bug. (#115015) 2023-12-02 20:35:10 +00:00
_prims Add support for torch.Generator type in TorchScript (#110413) 2023-11-21 23:07:21 +00:00
_prims_common [inductor] Allow sympy expressions to participate in type promotion (#115676) 2023-12-13 22:22:37 +00:00
_refs Add decomposition for torch.block_diag (#115096) 2023-12-11 20:04:22 +00:00
_subclasses Extend auto_functionalized to support ops that return Tensors (#115135) 2023-12-05 22:43:06 +00:00
_vendor vendor packaging.version (#114108) 2023-11-21 11:51:23 +00:00
amp Add Half support for CPU autocast on eager mode (#112484) 2023-11-21 20:08:28 +00:00
ao [quant][fx] Lower operator.matmul in convert_fx (#113954) 2023-12-12 00:34:58 +00:00
autograd [BC breaking] Remove check_sparse_nnz argument of gradcheck (#115658) 2023-12-13 17:34:30 +00:00
backends [MPS] Add MacOS 14 runtime check (#115512) 2023-12-11 21:11:42 +00:00
compiler
contrib
cpu
csrc Introduce 3 low-latency, intra-node allreduce algorithms for small messages to PyTorch (#114001) 2023-12-14 08:13:08 +00:00
cuda Add bsr_dense_addmm triton kernel (#114595) 2023-11-29 05:29:25 +00:00
distributed Let all_reduce_coalesced accept one tensor as well (#115650) 2023-12-13 21:32:01 +00:00
distributions Fix hang in VonMises rejection sampling for small values of concentration (#114498) 2023-12-04 23:07:06 +00:00
export [export] Preserve FQN in export_to_torch_ir (#115462) 2023-12-13 04:58:47 +00:00
fft
func
futures
fx [sigmoid] fix for FX tracing unflattened modules (#115708) 2023-12-13 19:43:46 +00:00
jit [BE][Easy]: Apply RUF019: remove duplicate checks for dict access (#114478) 2023-11-29 00:14:02 +00:00
legacy
lib
linalg
masked make_fx can now SymIntify int inputs (#113452) 2023-11-18 06:39:09 +00:00
monitor
mps
multiprocessing Robustify torch.multiprocessing.spawn error reporting to be less deadlock prone (#114688) 2023-12-09 03:36:43 +00:00
nested Fix SDPA for SAM (#115636) 2023-12-12 18:52:38 +00:00
nn Fix backward for SDPA NT jagged layout (#115576) 2023-12-12 18:35:40 +00:00
onnx Store user model to simplify ONNXProgram.{adapt_torch_*,__call__} APIs (#115281) 2023-12-09 07:46:12 +00:00
optim Added More Information About Adadelta Optimizer (#106290) 2023-12-05 05:55:16 +00:00
package
profiler
quantization
signal
sparse Add instructions for generating optimal Triton kernel parameters of bsr_dense_addmm (#115504) 2023-12-12 16:44:51 +00:00
special
testing [CUDA][FP8] Skip test_dtypes on FP8 _scaled_mm (#115661) 2023-12-14 05:12:33 +00:00
utils [pytree] expand tree_map to accept multi-inputs (#115642) 2023-12-14 06:16:42 +00:00
__config__.py
__future__.py
__init__.py Add is_integer to SymFloat (#114703) 2023-12-07 23:23:53 +00:00
_appdirs.py
_classes.py
_compile.py
_custom_ops.py
_deploy.py
_guards.py Add Stateful/Stateless symbolic contexts, use fresh fake mode for dynamo backends (#113926) (#114526) 2023-11-26 23:40:32 +00:00
_jit_internal.py
_linalg_utils.py
_lobpcg.py
_lowrank.py
_meta_registrations.py Fix backward for SDPA NT jagged layout (#115576) 2023-12-12 18:35:40 +00:00
_namedtensor_internals.py
_ops.py torch.compile should auto-functionalize certain mutable ops (#114955) 2023-12-05 14:53:08 +00:00
_python_dispatcher.py
_sources.py
_storage_docs.py
_streambase.py
_tensor.py Make Float8 types serializeable (#114662) 2023-11-29 23:23:23 +00:00
_tensor_docs.py [doc] two diff meanings of rv generated by torch.tensor.geometric_ and torch.distributions.geometric.Geometric (#113183) 2023-11-15 03:49:04 +00:00
_tensor_str.py Do not error when printing view created in no-grad modified in-place in no-grad (#113716) 2023-11-16 18:57:56 +00:00
_torch_docs.py Updated docs for deprecated torch.set_default_tensor_type (#115041) 2023-12-07 16:17:36 +00:00
_utils.py Make Float8 types serializeable (#114662) 2023-11-29 23:23:23 +00:00
_utils_internal.py [inductor][Observability] Add log for Optimus to enable easier debug (#110452) 2023-12-01 18:25:56 +00:00
_VF.py
_vmap_internals.py
_weights_only_unpickler.py Make Float8 types serializeable (#114662) 2023-11-29 23:23:23 +00:00
abi-check.cpp
CMakeLists.txt Revert "[Reland2] Update NVTX to NVTX3 (#109843)" 2023-12-05 16:10:20 +00:00
custom_class.h [Reland] [1/N] Fixes clang-tidy warnings in header files (#114668) 2023-11-29 07:11:51 +00:00
custom_class_detail.h
extension.h
functional.py make_fx can now SymIntify int inputs (#113452) 2023-11-18 06:39:09 +00:00
hub.py
library.h
library.py Optimize inspect.stack() call in caffe2/torch/library.py (#114700) 2023-11-29 20:54:02 +00:00
overrides.py Add python and C++ support for LPPool3d (#114199) 2023-12-08 18:18:44 +00:00
py.typed
quasirandom.py
random.py
README.txt
return_types.py [pytree] register pytree node type in both C++ pytree and Python pytree (#112111) 2023-11-28 11:41:38 +00:00
script.h
serialization.py [BE] Do not warn when safely loading legacy dicts (#113614) 2023-11-14 22:09:10 +00:00
storage.py
torch_version.py vendor packaging.version (#114108) 2023-11-21 11:51:23 +00:00
types.py improve annotation device parameters where a device ordinal is allowed (#113647) 2023-11-17 14:41:22 +00:00
version.py.tpl

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.