ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
Tianlei Wu baaef59696
Add sparse attention kernel for H100 (sm90) (#20553)
### Description
Follow-up to https://github.com/microsoft/onnxruntime/pull/20216 that adds a
sparse attention kernel compiled by Triton for H100 (sm90).
- [x] Refine sparse attention v1 kernel compilation (remove some
combinations)
- [x] compile kernels for v1 kernels
- [x] compile kernels for H100
- [x] run performance tests

### Performance

Test setting `batch_size=4, num_heads=32, max_seq_len=8192,
head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8,
num_layout=8`

We compare sparse attention with the corresponding GQA using a local attention
window of size 1024, and with GQA using dense causal attention. Note that
ORT-GQA-Dense performs more computation than ORT-SparseAtt, while ORT-GQA-Local
performs less (it has no vertical strides). Both are included for reference
only. This is not a fair comparison, but it shows the benefit of sparse over
dense attention.
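To make the difference in computation concrete, here is a minimal sketch of a block-sparse causal layout built from the test setting above. It assumes each query block attends to the most recent `local_blocks` key blocks plus every `vert_stride`-th key block. The function `block_mask` and its `head_offset` parameter are hypothetical names for illustration, and the actual kernel layout (including how the `num_layout` layouts differ per head) may not match this.

```python
# Sketch only: an assumed local + vertical-stride block layout, not the kernel's code.
def block_mask(max_blocks, local_blocks, vert_stride, head_offset=0):
    """Return a max_blocks x max_blocks boolean block mask (True = attend)."""
    mask = []
    for q in range(max_blocks):
        row = []
        for k in range(max_blocks):
            causal = k <= q                       # never attend to future blocks
            local = (q - k) < local_blocks        # sliding local window
            strided = (k + head_offset + 1) % vert_stride == 0  # vertical stride
            row.append(causal and (local or strided))
        mask.append(row)
    return mask


# Values taken from the test setting above.
max_seq_len, sparse_block_size = 8192, 64
local_blocks, vert_stride = 16, 8

max_blocks = max_seq_len // sparse_block_size          # 128 blocks
local_window_tokens = local_blocks * sparse_block_size  # 1024 tokens

mask = block_mask(max_blocks, local_blocks, vert_stride)
sparse_blocks = sum(sum(row) for row in mask)           # blocks the sparse layout visits
dense_blocks = max_blocks * (max_blocks + 1) // 2       # blocks dense causal visits

print(f"local window: {local_window_tokens} tokens")
print(f"attended blocks: sparse={sparse_blocks}, dense causal={dense_blocks}")
```

Under this assumed layout, the local window spans 16 × 64 = 1024 tokens, which is the window size used for ORT-GQA-Local. The vertical-stride blocks are the extra work that ORT-SparseAtt does on top of the local window.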

Example results on an Azure Standard_ND96isr_H100_v5 VM with an NVIDIA
H100-80GB-HBM3 GPU (sm=90):
```
    prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0             16.0   0.079877       0.006362       0.006403       0.042758
    1             32.0   0.086920       0.016404       0.016686       0.044183
    2             64.0   0.090727       0.020429       0.020409       0.045343
    3            128.0   0.128148       0.032009       0.031984       0.051516
    4            256.0   0.323933       0.074110       0.073920       0.068308
    5            512.0   1.021856       0.162167       0.161951       0.109226
    6           1024.0   3.596002       0.452629       0.452780       0.231653
    7           2048.0  13.865088       1.499534       1.195749       0.515488
    8           4096.0   0.000000       5.454785       2.669682       1.163233
    9           8192.0   0.000000      22.068159       6.018604       2.772873

    token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0                  16.0   0.104460       0.012652       0.012661       0.069549
    1                  32.0   0.113866       0.012776       0.012765       0.069024
    2                  64.0   0.124600       0.016791       0.012672       0.069397
    3                 128.0   0.108658       0.017900       0.018294       0.074844
    4                 256.0   0.115463       0.029409       0.029608       0.078911
    5                 512.0   0.149824       0.033968       0.033701       0.092998
    6                1024.0   0.234050       0.042930       0.042951       0.116920
    7                2048.0   0.390695       0.061462       0.043008       0.121555
    8                4096.0   0.000000       0.097505       0.042948       0.134757
    9                8191.0   0.000000       0.165861       0.043542       0.158796
```
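As a reading aid, the sketch below computes latency ratios from a few rows of the prompt table above. The numbers are copied from the table. The dictionary and function names are illustrative only, and the ratios are simple divisions, not additional measurements.

```python
# Prompt latencies copied from the table above:
# sequence_length -> (ORT-GQA-Dense, ORT-GQA-Local, ORT-SparseAtt)
prompt_latency = {
    2048: (1.499534, 1.195749, 0.515488),
    4096: (5.454785, 2.669682, 1.163233),
    8192: (22.068159, 6.018604, 2.772873),
}


def ratios_vs_sparse(dense, local, sparse):
    """Return (dense / sparse, local / sparse) latency ratios."""
    return dense / sparse, local / sparse


speedup = {
    seq_len: ratios_vs_sparse(*latencies)
    for seq_len, latencies in prompt_latency.items()
}

for seq_len, (vs_dense, vs_local) in speedup.items():
    print(f"seq_len={seq_len}: dense/sparse={vs_dense:.2f}, local/sparse={vs_local:.2f}")
```

The dense-to-sparse ratio grows with sequence length in these rows, consistent with dense causal attention doing more work per query as the sequence grows.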
The following change might improve performance at short sequence lengths. It
requires an update to the operator spec:
 Fall back to flash attention when total_sequence_length < local_blocks * block_size
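The proposed fallback rule can be sketched as a simple predicate. This is not the operator's dispatch logic; the function name is hypothetical and only mirrors the condition stated above.

```python
# Hypothetical helper illustrating the proposed fallback condition.
def should_fall_back_to_flash_attention(total_sequence_length, local_blocks, block_size):
    """True when the whole sequence is shorter than the local window."""
    return total_sequence_length < local_blocks * block_size


# With local_blocks=16 and block_size=64, the threshold is 1024 tokens.
for total_len in (16, 512, 1024, 8192):
    print(total_len, should_fall_back_to_flash_attention(total_len, 16, 64))
```

The condition presumably rests on the fact that a sequence shorter than the local window has every causal block inside that window, so the sparse layout would select the same blocks as dense causal attention.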

2024-05-04 19:53:32 -07:00
.config
.devcontainer
.gdn
.github
.pipelines
.vscode
cgmanifests
cmake
csharp
dockerfiles
docs
include/onnxruntime/core
java
js
objectivec
onnxruntime Add sparse attention kernel for H100 (sm90) (#20553) 2024-05-04 19:53:32 -07:00
orttraining
rust
samples
tools
winml
.clang-format
.clang-tidy
.dockerignore
.gitattributes
.gitignore
.gitmodules
.lintrunner.toml
build.bat
build.sh
build_arm64x.bat
CITATION.cff
CODEOWNERS
CONTRIBUTING.md
lgtm.yml
LICENSE
NuGet.config
ort.wprp
ORT_icon_for_light_bg.png
packages.config
pyproject.toml
README.md
requirements-dev.txt
requirements-doc.txt
requirements-lintrunner.txt
requirements-training.txt
requirements.txt.in
SECURITY.md
setup.py
ThirdPartyNotices.txt
VERSION_NUMBER

ONNX Runtime is a cross-platform inference and training machine-learning accelerator.

ONNX Runtime inference can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, XGBoost, etc. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms.

ONNX Runtime training can reduce model training time on multi-node NVIDIA GPUs for transformer models with a one-line addition to existing PyTorch training scripts.

Get Started & Resources

Builtin Pipeline Status

Build status badges for inference and training pipelines on: Windows, Linux, Mac, Android, iOS, Web, Other.

Third-party Pipeline Status

Build status badge for: Linux.

Data/Telemetry

Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the privacy statement for more details.

Contributions and Feedback

We welcome contributions! Please see the contribution guidelines.

For feature requests or bug reports, please file a GitHub Issue.

For general discussion or questions, please use GitHub Discussions.

Code of Conduct

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

License

This project is licensed under the MIT License.