mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
## Context

This stack prototypes automatic micro-pipelining of `all-gather -> matmul` and `matmul -> reduce-scatter` via Inductor. The idea originates from the paper [Overlap Communication with Dependent Computation via Decomposition in Large Deep Learning Models](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959). The implementation and some key optimizations are heavily influenced by @lw's implementation in xformers.

The stack contains several components:
- `ProcessGroupCudaP2P` - a thin wrapper around `ProcessGroupNCCL`. In addition, it maintains a P2P workspace that enables SM-free, one-sided P2P communication, which is needed for optimal micro-pipelining.
- `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops.
- Post-grad fx pass that detects `all-gather -> matmul` and `matmul -> reduce-scatter` and replaces them with the fused dispatcher ops.

To enable the prototype feature:
- Set the distributed backend to `cuda_p2p`.
- Set `torch._inductor.config._micro_pipeline_tp` to `True`.

*NOTE: the prototype sets nothing in stone w.r.t. each component's design. The purpose is to have a performant baseline with a reasonable design on which each component can be further improved.*

## Benchmark

Setup:
- 8 x H100 (500W) + 3rd gen NVSwitch.
- Llama3 8B training w/ torchtitan.
- 8-way TP. Reduced the number of layers from 32 to 8 for benchmarking purposes.

Trace (baseline): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpjaz8zgx0

<img width="832" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4addba77-5abc-4d2e-93ea-f68078587fe1">

Trace (w/ micro pipelining): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpn073b4wn

<img width="963" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4f44e78d-8196-43ab-a1ea-27390f07e9d2">

## This PR

`ProcessGroupCudaP2P` is a thin wrapper around `ProcessGroupNCCL`. By default, it routes all collectives to the underlying `ProcessGroupNCCL`. In addition, `ProcessGroupCudaP2P` initializes a P2P workspace that allows direct GPU memory access among the members. The workspace can be used in Python to optimize intra-node communication patterns or to create custom intra-node collectives in CUDA.

`ProcessGroupCudaP2P` aims to bridge the gap where certain important patterns can be better optimized via fine-grained P2P memory access than with collectives in the latest version of NCCL. It is meant to complement NCCL rather than replace it.

Usage:

```
# Using ProcessGroupCudaP2P
dist.init_process_group(backend="cuda_p2p", ...)

# Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
pg_options = ProcessGroupCudaP2P.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
pg_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Using ProcessGroupCudaP2P while specifying both
# ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
pg_options = ProcessGroupCudaP2P.Options()
pg_options.nccl_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Down-casting the backend to access p2p buffers for cuda_p2p specific
# optimizations
if is_cuda_p2p_group(group):
    backend = get_cuda_p2p_backend(group)
    if required_p2p_buffer_size > backend.get_buffer_size():
        # fallback
    p2p_buffer = backend.get_p2p_buffer(...)
else:
    # fallback
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122163
Approved by: https://github.com/wanchaol
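The decomposition idea behind `all-gather -> matmul` micro-pipelining can be illustrated without GPUs. The following is a minimal pure-Python sketch, not the implementation from this stack: it assumes the activation is sharded by rows across ranks and the weight is replicated, and it only checks that multiplying one shard at a time gives the same result as gathering first. The function names are hypothetical and are not PyTorch APIs; no communication or overlap actually happens here.

```python
# Illustrative simulation only: no GPUs, no torch.distributed, no real overlap.
# All names below are hypothetical and not part of PyTorch.


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    cols_b = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols_b] for row in a]


def all_gather_then_matmul(shards, weight):
    """Baseline: concatenate every rank's row shard, then run one matmul."""
    gathered = [row for shard in shards for row in shard]
    return matmul(gathered, weight)


def pipelined_all_gather_matmul(shards, weight, rank):
    """Decomposed form as seen from a single rank.

    At step s the rank multiplies the shard owned by rank (rank + s) % world_size.
    On real hardware the copy of the next shard could overlap with this matmul;
    in this sketch the steps simply run one after another.
    """
    world_size = len(shards)
    out_chunks = [None] * world_size
    for step in range(world_size):
        src = (rank + step) % world_size  # step 0 uses the local shard
        out_chunks[src] = matmul(shards[src], weight)
    return [row for chunk in out_chunks for row in chunk]


# Three simulated ranks, each holding a 2x2 row shard of the activation.
shards = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
weight = [[1, 0, 2], [0, 1, 3]]

baseline = all_gather_then_matmul(shards, weight)
results = [pipelined_all_gather_matmul(shards, weight, r) for r in range(len(shards))]
print(all(r == baseline for r in results))
```

Because each output row block depends on only one input shard, the per-shard matmuls are independent, which is what leaves room to hide the transfer of the next shard behind the current computation.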
| Name | | |
|---|---|---|
| distributed | ||
| dynamo | ||
| fastrnns | ||
| framework_overhead_benchmark | ||
| functional_autograd_benchmark | ||
| fuser | ||
| gpt_fast | ||
| inference | ||
| instruction_counts | ||
| nested | ||
| operator_benchmark | ||
| overrides_benchmark | ||
| profiler_benchmark | ||
| record_function_benchmark | ||
| serialization | ||
| sparse | ||
| static_runtime | ||
| tensorexpr | ||
| transformer | ||
| compare-fastrnn-results.py | ||
| compare.sh | ||
| README.md | ||
| upload_scribe.py | ||
# PyTorch Benchmarks
This folder contains scripts that produce reproducible timings of various PyTorch features.
It also provides mechanisms to compare PyTorch with other frameworks.
## Setup environment
Make sure you're on a machine with CUDA, torchvision, and pytorch installed. Install in the following order:
```
# Install torchvision. It comes with the pytorch stable release binary
conda install pytorch torchvision -c pytorch

# Install the latest pytorch master from source.
# It should supersede the installation from the release binary.
cd $PYTORCH_HOME
python setup.py build develop

# Check the pytorch installation version
python -c "import torch; print(torch.__version__)"
```
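Before running a benchmark, it can help to confirm all three prerequisites at once. The script below is a minimal sketch rather than part of this folder: `check_environment` and `cuda_available` are hypothetical helper names, and the script only reports what is importable in the current interpreter.

```python
import importlib
import importlib.util


def check_environment(packages=("torch", "torchvision")):
    """Return {package: version string, or None if it is not installed}."""
    report = {}
    for name in packages:
        if importlib.util.find_spec(name) is None:
            report[name] = None
            continue
        module = importlib.import_module(name)
        report[name] = getattr(module, "__version__", "unknown")
    return report


def cuda_available():
    """True only if torch is importable and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch

    return torch.cuda.is_available()


if __name__ == "__main__":
    for name, version in check_environment().items():
        print(f"{name}: {version or 'not installed'}")
    print(f"CUDA available: {cuda_available()}")
```

If the reported `torch` version is the stable release rather than a source build, the build step above probably did not take effect in this environment.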
## Benchmark List
Please refer to each subfolder to discover each benchmark suite. Links are provided where descriptions exist: