This PR introduces the following:
### torch.ops.symm_mem._async_input_mm
`_async_input_mm(Tensor a, Tensor b, Tensor a_chunk_signals, int a_chunk_pivot) -> Tensor`
An mm implementation that can consume asynchronous input. It guarantees the following rasterization order, and it guarantees that the signal for each input chunk has arrived before that chunk is consumed.
```python
num_chunks = a_chunk_signals.numel()
for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
    chunk_idx = chunk_idx % num_chunks
    wait_signal(a_chunk_signals, chunk_idx)
    # Compute the output tiles that consume this input chunk
```
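To make the ordering contract concrete, here is a CPU-only model in plain Python. It assumes `threading.Event` objects stand in for the device-side `a_chunk_signals`, and that a hypothetical producer thread fills the chunks of `a` in random order. The names `async_input_mm_reference`, `produce`, and `matmul_rows` are illustrative only and are not part of this PR.

```python
import random
import threading
import time


def matmul_rows(a_rows, b):
    """Plain-Python matmul of a block of rows of A with B."""
    return [
        [sum(a_row[k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for a_row in a_rows
    ]


def async_input_mm_reference(a, b, chunk_ready, a_chunk_pivot):
    """Hypothetical CPU model of the _async_input_mm ordering contract.

    chunk_ready[i] is a threading.Event standing in for a_chunk_signals[i] == 1.
    Returns the output rows and the order in which chunks were consumed.
    """
    num_chunks = len(chunk_ready)
    rows_per_chunk = len(a) // num_chunks  # assumes M divides evenly into chunks
    out = [None] * len(a)
    order = []
    for chunk_idx in range(a_chunk_pivot, num_chunks + a_chunk_pivot):
        chunk_idx = chunk_idx % num_chunks
        chunk_ready[chunk_idx].wait()  # models wait_signal(a_chunk_signals, chunk_idx)
        lo, hi = chunk_idx * rows_per_chunk, (chunk_idx + 1) * rows_per_chunk
        out[lo:hi] = matmul_rows(a[lo:hi], b)  # output rows that consume this chunk
        order.append(chunk_idx)
    return out, order


def produce(a, src, chunk_ready, rows_per_chunk):
    """Hypothetical producer that copies chunks of A in random order, then signals."""
    ids = list(range(len(chunk_ready)))
    random.shuffle(ids)
    for i in ids:
        lo = i * rows_per_chunk
        a[lo:lo + rows_per_chunk] = src[lo:lo + rows_per_chunk]
        time.sleep(0.001)  # simulate transfer latency
        chunk_ready[i].set()  # chunk i is now ready


M, K, N, num_chunks = 8, 3, 2, 4
src = [[r + k for k in range(K)] for r in range(M)]
b = [[k * N + j for j in range(N)] for k in range(K)]
a = [[0] * K for _ in range(M)]  # filled asynchronously by the producer
ready = [threading.Event() for _ in range(num_chunks)]

producer = threading.Thread(target=produce, args=(a, src, ready, M // num_chunks))
producer.start()
out, order = async_input_mm_reference(a, b, ready, a_chunk_pivot=1)
producer.join()
print(order)
```

Whatever order the producer uses, the consumer visits chunks 1, 2, 3, 0 for `a_chunk_pivot=1`, blocks on any chunk that has not been signaled yet, and produces the same result as a synchronous matmul over the fully populated input.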
### PersistentAsyncInputScheduler
This is a forked version of PersistentScheduler that supports consuming asynchronous input. This tile scheduler introduces the following arguments:
- `tiles_per_chunk_m` – Specifies the size of an M chunk, in M tiles. Chunks are the granularity at which the asynchronous input becomes ready, so a chunk must be an integer multiple of the M tile size.
- `chunk_signals` – `chunk_signals[i] == 1` indicates that chunk i is ready. Before returning a work tile, `get_current_work()` waits on the signal to ensure that the corresponding chunk is ready.
- `tile_idx_pivot_m` – After applying swizzling, apply `pivot(m) => (m + tile_idx_pivot_m) % tiles_m` to `m`. In a distributed setting, this allows different ranks to process different m indices at the same time, thus avoiding communication hotspots.
Note that this scheduler currently only supports the `KernelTmaWarpSpecializedCooperative` kernel schedule. This is enforced via the template argument `KernelSchedule`.
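The index arithmetic behind `tile_idx_pivot_m` and `tiles_per_chunk_m` can be sketched in Python. This is a simplified, hypothetical model: it omits swizzling, assumes n varies fastest in the linear work order, and ignores how a persistent kernel distributes work indices across CTAs. With `tiles_per_chunk_m=2` and `tile_idx_pivot_m=2`, consumption starts at chunk 1; the assumption here is that a chunk-level pivot `c` corresponds to `tile_idx_pivot_m = c * tiles_per_chunk_m`.

```python
def pivot_m(m, tile_idx_pivot_m, tiles_m):
    """pivot(m) => (m + tile_idx_pivot_m) % tiles_m, as described above."""
    return (m + tile_idx_pivot_m) % tiles_m


def schedule(tiles_m, tiles_n, tiles_per_chunk_m, tile_idx_pivot_m):
    """Hypothetical, swizzle-free model of the scheduler's work order.

    Yields (m, n, chunk) for each work tile in linear work-index order, where
    chunk is the index into chunk_signals that must equal 1 before the tile runs.
    """
    for linear in range(tiles_m * tiles_n):
        m, n = divmod(linear, tiles_n)  # assumption: n varies fastest, no swizzle
        m = pivot_m(m, tile_idx_pivot_m, tiles_m)
        yield m, n, m // tiles_per_chunk_m


work = list(schedule(tiles_m=4, tiles_n=2, tiles_per_chunk_m=2, tile_idx_pivot_m=2))
print([chunk for _, _, chunk in work])
```

Every (m, n) tile is still visited exactly once; the pivot only rotates which M rows, and therefore which chunk signal, come first.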
Usage:
```cpp
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,
    CollectiveMainloop,
    CollectiveEpilogue,
    cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
```
### _fused_all_gather_matmul_native
An all-gather + matmul (ag-mm) implementation that combines `torch.ops.symm_mem._async_input_mm` with a progress-aware all-gather. It is not yet enabled by the async-tp passes; we plan to use it as a backend to optimize the current decomposition-based async-tp implementation.
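Why per-rank pivots help can be illustrated with an idealized Python model in which all ranks advance one chunk per step in lockstep. The choice `a_chunk_pivot = rank` is an assumption for illustration (each rank starts on its own local shard, which needs no waiting); this description does not state which pivot `_fused_all_gather_matmul_native` uses. The helper names are hypothetical.

```python
def chunk_order(num_chunks, a_chunk_pivot):
    """Chunk consumption order implied by the _async_input_mm pseudocode."""
    return [(a_chunk_pivot + i) % num_chunks for i in range(num_chunks)]


def chunks_per_step(world_size, pivot_for_rank):
    """For each lockstep step, list the chunk that every rank reads."""
    orders = [chunk_order(world_size, pivot_for_rank(r)) for r in range(world_size)]
    return [[orders[r][step] for r in range(world_size)] for step in range(world_size)]


world_size = 4
# Assumption: rank r's local shard is chunk r, so pivot = r starts on local data.
staggered = chunks_per_step(world_size, lambda rank: rank)
# Contrast: every rank uses pivot 0 and starts on the same chunk.
aligned = chunks_per_step(world_size, lambda rank: 0)

for step, chunks in enumerate(staggered):
    # Each step touches every chunk exactly once, so no shard is read by two ranks at once.
    assert sorted(chunks) == list(range(world_size)), step
print(aligned[0])  # with a shared pivot, all ranks read the same chunk first
```

With a shared pivot, every rank contends for the same shard at each step; rotating the pivot per rank spreads reads across all shards.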
## Benchmarks
### 4096x3584x8192
- cublas + nccl: 539us
- decomp-based async-tp w/o cuda graph: 694us
- decomp-based async-tp w/ cuda graph: 478us
- new cutlass kernel: 408us
<img width="478" alt="image" src="https://github.com/user-attachments/assets/39f316ab-36c5-4b41-af77-07854a385dfc">
### 2048x3584x8192
- cublas + nccl: 301us
- decomp-based async-tp w/o cuda graph: 687us
- decomp-based async-tp w/ cuda graph: 356us
- new cutlass kernel: 276us
<img width="441" alt="image" src="https://github.com/user-attachments/assets/9e23ce21-863b-43dd-a562-fb05d3a5a144">
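To read the benchmark numbers as relative gains, the short script below divides each baseline's latency by the new kernel's latency (values above 1 mean the new kernel is faster). The latencies are copied from the lists above; the shape strings are used as-is.

```python
# Latencies in microseconds, copied from the benchmark lists above
results = {
    "4096x3584x8192": {
        "cublas + nccl": 539,
        "decomp-based async-tp w/o cuda graph": 694,
        "decomp-based async-tp w/ cuda graph": 478,
        "new cutlass kernel": 408,
    },
    "2048x3584x8192": {
        "cublas + nccl": 301,
        "decomp-based async-tp w/o cuda graph": 687,
        "decomp-based async-tp w/ cuda graph": 356,
        "new cutlass kernel": 276,
    },
}

speedups = {}
for shape, latencies in results.items():
    new = latencies["new cutlass kernel"]
    # Speedup of the new kernel over each baseline, rounded to two decimals
    speedups[shape] = {
        name: round(us / new, 2)
        for name, us in latencies.items()
        if name != "new cutlass kernel"
    }
    print(shape, speedups[shape])
```

For 4096x3584x8192 this gives 1.32x over cuBLAS + NCCL and 1.17x over the CUDA-graph decomposition; for 2048x3584x8192 the corresponding figures are 1.09x and 1.29x.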
## Next Steps
- Add tuning logic
- Use `_fused_all_gather_matmul_native` as a backend for the decomp-based async-tp impl
Differential Revision: [D65623152](https://our.internmc.facebook.com/intern/diff/D65623152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139227
Approved by: https://github.com/weifengpy, https://github.com/Chillee