ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
Tianlei Wu 6ffaaebb60
[CUDA] Attention kernel provider option (#21344)
### Description
* Add a CUDA provider option `sdpa_kernel` to choose which attention kernel to run, for testing purposes.
* Allow dumping which attention kernel is used per node.
* Reserve a flag for cuDNN flash attention, which will be added soon.

#### CUDA provider option sdpa_kernel
In addition to the environment variable, the attention kernel can now be
selected through a provider option. Note that the setting is global per
session. This can help with performance testing of each kernel.
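
A minimal sketch of how the provider option might be passed from Python. The option name `sdpa_kernel` comes from the description above; the helper name `cuda_providers_with_sdpa_kernel` is hypothetical, and the value encoding is not described here, so the value used below is a placeholder assumption rather than a documented setting.

```python
def cuda_providers_with_sdpa_kernel(sdpa_kernel_value):
    """Build a `providers` list that sets `sdpa_kernel` on the CUDA provider.

    Hypothetical helper: the option value is passed through as a string
    because its encoding is not specified in this description.
    """
    cuda_options = {"sdpa_kernel": str(sdpa_kernel_value)}
    # CPU is listed second as a fallback for nodes the CUDA provider does not take.
    return [("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider"]


providers = cuda_providers_with_sdpa_kernel(2)  # 2 is a placeholder value

# With onnxruntime installed, a CUDA device available, and a model on disk:
#   import onnxruntime
#   session = onnxruntime.InferenceSession("model.onnx", providers=providers)
```

Because the setting is global per session, comparing kernels would mean creating one session per value and timing each separately.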

#### Attention Kernel Debug Info
Set the environment variable `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`,
and ORT will print the SDPA kernel used in each node.

For example:
```
ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 ./onnxruntime_test_all --gtest_filter=MultiHeadAttentionTest*
```
This prints debug information about the kernels used in the tests:
```
[ RUN      ] MultiHeadAttentionTest.SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=0 TRT_FUSED_ATTENTION=1 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=1 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 TRT_FUSED_ATTENTION=1
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=1 TRT_FUSED_ATTENTION=0 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=0 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 EFFICIENT_ATTENTION=1
```
In this test case, the debug info shows that one session uses TRT fused
attention and another session uses efficient attention.
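
The debug output above can be summarized with a small parser. This is only a sketch: it assumes the line layout is exactly as in the sample (an `AttentionKernelOptions:` line per session followed by `Operator=... Node=... DataType=... <KERNEL>=1` lines), and the function name `parse_attention_debug` is hypothetical, not part of ORT.

```python
SAMPLE_LOG = """\
[ RUN      ] MultiHeadAttentionTest.SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=0 TRT_FUSED_ATTENTION=1 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=1 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 TRT_FUSED_ATTENTION=1
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=1 TRT_FUSED_ATTENTION=0 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=0 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 EFFICIENT_ATTENTION=1
"""


def parse_attention_debug(log_text):
    """Group per-node kernel choices under the options line that precedes them."""
    sessions = []
    for raw_line in log_text.splitlines():
        line = raw_line.strip()
        if line.startswith("AttentionKernelOptions:"):
            # Assumption: each options line marks the start of a new session.
            enabled = {}
            for pair in line.split(":", 1)[1].split():
                name, value = pair.split("=", 1)
                enabled[name] = value == "1"
            sessions.append({"options": enabled, "nodes": []})
        elif line.startswith("Operator=") and sessions:
            fields = dict(pair.split("=", 1) for pair in line.split())
            known = ("Operator", "Node", "DataType")
            # The remaining NAME=1 field names the kernel chosen for this node.
            kernels = [k for k, v in fields.items() if k not in known and v == "1"]
            sessions[-1]["nodes"].append({
                "operator": fields.get("Operator"),
                "node": fields.get("Node"),
                "dtype": fields.get("DataType"),
                "kernel": kernels[0] if kernels else None,
            })
    return sessions


sessions = parse_attention_debug(SAMPLE_LOG)
chosen = [s["nodes"][0]["kernel"] for s in sessions]
```

On the sample log, `chosen` holds one kernel name per session, which matches the reading given above: TRT fused attention for the first session and efficient attention for the second.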
2024-07-19 13:58:54 -07:00
.config
.devcontainer
.gdn
.github Fix lint C++ actions (#21303) 2024-07-11 09:46:41 +08:00
.pipelines Fix onebranch exception in code signing (#21088) 2024-06-19 12:07:17 +08:00
.vscode disable gemm f16 on CPU (#19744) 2024-03-01 13:44:29 -08:00
cgmanifests Update absl (#21300) 2024-07-10 11:14:15 -07:00
cmake [CUDA] Attention kernel provider option (#21344) 2024-07-19 13:58:54 -07:00
csharp Fix typos - 1st Wave (#21278) 2024-07-11 13:35:08 +08:00
dockerfiles Update Dockerfile.cuda (#21042) 2024-06-13 23:50:03 -07:00
docs [CPU] SparseAttention op (#21110) 2024-07-03 21:51:57 -07:00
include/onnxruntime/core [CUDA] Attention kernel provider option (#21344) 2024-07-19 13:58:54 -07:00
java Fix Android build on Windows (#21304) 2024-07-15 12:29:02 -07:00
js [js/web] fix vulnerable version of dependencies (#21412) 2024-07-19 11:11:30 -07:00
objectivec Fix Objective-C static analysis warnings. (#20417) 2024-04-24 11:48:29 -07:00
onnxruntime [CUDA] Attention kernel provider option (#21344) 2024-07-19 13:58:54 -07:00
orttraining change ci docker image to rocm6.1 (#21296) 2024-07-18 14:50:01 +08:00
rust
samples
tools Update azure-kusto-data and azure-kusto-ingest (#21409) 2024-07-18 14:26:26 -07:00
winml Remove core/common/gsl.h (#20894) 2024-07-08 18:09:39 -07:00
.clang-format
.clang-tidy
.dockerignore
.gitattributes
.gitignore Build onnxruntime.dll as arm64x (#18633) 2023-12-06 16:49:00 -08:00
.gitmodules [js/web] optimize module export and deployment (#20165) 2024-05-20 09:51:16 -07:00
.lintrunner.toml Make Flash Attention work on Windows (#21015) 2024-06-24 09:43:49 -07:00
build.bat
build.sh
build_arm64x.bat remove unnecessary environment variable (#19166) 2024-01-16 16:24:37 -08:00
CITATION.cff Fix citation author name issue (#19597) 2024-02-22 17:03:56 -08:00
CODEOWNERS
CONTRIBUTING.md
lgtm.yml
LICENSE
NuGet.config
ort.wprp Fully dynamic ETW controlled logging for ORT and QNN logs (#20537) 2024-06-06 21:11:14 -07:00
ORT_icon_for_light_bg.png
packages.config Update DML to 1.14.1 (#20380) 2024-04-18 22:43:41 -07:00
pyproject.toml [CUDA] Add SparseAttention operator for Phi-3-small (#20216) 2024-04-30 09:06:29 -07:00
README.md Update README.md (#18963) 2024-01-03 17:26:25 -08:00
requirements-dev.txt
requirements-doc.txt
requirements-lintrunner.txt Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
requirements-training.txt
requirements.txt Add compatibility for NumPy 2.0 (#21085) 2024-06-27 13:50:53 -07:00
SECURITY.md
setup.py Migraphx ep windows build (#21284) 2024-07-11 21:21:38 -07:00
ThirdPartyNotices.txt Fix HalideIR title in third party notices reference (#20190) 2024-04-05 11:12:43 -07:00
VERSION_NUMBER Bump up version in main from 1.18.0 to 1.19.0 (#20489) 2024-04-29 20:21:41 -07:00

ONNX Runtime is a cross-platform inference and training machine-learning accelerator.

ONNX Runtime inference can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, XGBoost, etc. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms. Learn more →

ONNX Runtime training can reduce model training time on multi-node NVIDIA GPUs for transformer models with a one-line addition to existing PyTorch training scripts. Learn more →

Get Started & Resources

Builtin Pipeline Status

[Build status badges per system (Windows, Linux, Mac, Android, iOS, Web, Other), in Inference and Training columns]

Third-party Pipeline Status

[Build status badge for Linux, in Inference and Training columns]

Data/Telemetry

Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the privacy statement for more details.

Contributions and Feedback

We welcome contributions! Please see the contribution guidelines.

For feature requests or bug reports, please file a GitHub Issue.

For general discussion or questions, please use GitHub Discussions.

Code of Conduct

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

License

This project is licensed under the MIT License.