onnxruntime/cmake
aciddelgado ebd0368bb0
Make Flash Attention work on Windows (#21015)
### Description
Previously, Flash Attention only worked on Linux systems. This PR
enables it to be built and run on Windows.

Limitation of Flash Attention on Windows: it requires CUDA 12.

### Motivation and Context
This significantly increases the performance of Windows-based LLMs
on hardware with compute capability sm >= 80.

To illustrate the improvement of Flash Attention over Memory Efficient
Attention, here are some average benchmark numbers for the GQA operator,
run with configurations based on several recent models (Llama, Mixtral,
Phi-3). The benchmarks were obtained on an RTX 4090 GPU using the test
script at onnxruntime/test/python/transformers/benchmark_gqa_windows.py.

* Clarifying Note: These benchmarks are just for the GQA operator, not
the entire model.

### Memory Efficient Attention Kernel Benchmarks:
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|-------------------------------------|--------|-------------|-------------|
| Llama3-8B (Average Prompt)          | 8192   | 0.19790525  | 13105.63425 |
| Llama3-8B (Average Token)           | 8192   | 0.207775538 | 12025.10172 |
| Llama3-70B (Average Prompt)         | 8192   | 0.216049167 | 11563.31185 |
| Llama3-70B (Average Token)          | 8192   | 0.209730731 | 12284.38149 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768  | 0.371928785 | 7031.440056 |
| Mixtral-8x22B-v0.1 (Average Token)  | 32768  | 0.2996659   | 7607.947159 |
| Phi-3-mini-128k (Average Prompt)    | 131072 | 0.183195867 | 15542.0852  |
| Phi-3-mini-128k (Average Token)     | 131072 | 0.198215688 | 12874.53494 |
| Phi-3-small-128k (Average Prompt)   | 65536  | 2.9884929   | 2332.584142 |
| Phi-3-small-128k (Average Token)    | 65536  | 0.845072406 | 2877.85822  |
| Phi-3-medium-128K (Average Prompt)  | 32768  | 0.324974429 | 8094.909517 |
| Phi-3-medium-128K (Average Token)   | 32768  | 0.263662567 | 8978.463687 |

### Flash Attention Kernel Benchmarks:
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|-------------------------------------|--------|-------------|-------------|
| Llama3-8B (Average Prompt)          | 8192   | 0.163566292 | 16213.69057 |
| Llama3-8B (Average Token)           | 8192   | 0.161643692 | 16196.14715 |
| Llama3-70B (Average Prompt)         | 8192   | 0.160510375 | 17448.67753 |
| Llama3-70B (Average Token)          | 8192   | 0.169427308 | 14702.62043 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768  | 0.164121964 | 15618.51301 |
| Mixtral-8x22B-v0.1 (Average Token)  | 32768  | 0.1715865   | 14524.32273 |
| Phi-3-mini-128k (Average Prompt)    | 131072 | 0.167527167 | 14576.725   |
| Phi-3-mini-128k (Average Token)     | 131072 | 0.175940594 | 15762.051   |
| Phi-3-small-128k (Average Prompt)   | 65536  | 0.162719733 | 17824.494   |
| Phi-3-small-128k (Average Token)    | 65536  | 0.14977525  | 16749.19858 |
| Phi-3-medium-128K (Average Prompt)  | 32768  | 0.156490786 | 17679.2513  |
| Phi-3-medium-128K (Average Token)   | 32768  | 0.165333833 | 14932.26079 |

Flash Attention is consistently faster for every configuration we
benchmarked, with throughput improvements in our trials ranging from
~20% to ~650%.
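The headline range can be reproduced directly from the throughput columns of the two tables. A small Python sketch (not part of the PR) computing the smallest and largest gains:

```python
# Percent throughput improvement of Flash Attention over Memory Efficient
# Attention, using samples/second values copied from the tables above.
memory_efficient = {
    "Llama3-70B (Average Token)": 12284.38149,
    "Phi-3-small-128k (Average Prompt)": 2332.584142,
}
flash = {
    "Llama3-70B (Average Token)": 14702.62043,
    "Phi-3-small-128k (Average Prompt)": 17824.494,
}

for name in memory_efficient:
    # Improvement = (flash / baseline - 1) * 100
    pct = (flash[name] / memory_efficient[name] - 1) * 100
    print(f"{name}: +{pct:.0f}%")
# Llama3-70B token decoding is the smallest gain (~20%);
# the Phi-3-small prompt case is the largest (~660%).
```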

In addition to these improvements in performance, Flash Attention has
better memory usage. For example, Memory Efficient Attention cannot
handle a max sequence length higher than 32,768, but Flash Attention can
handle max sequence lengths at least as high as 131,072.

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
2024-06-24 09:43:49 -07:00
external [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
patches Update protobuf_cmake.patch to allow extra disablements configurable by projects that build ORT (#20875) 2024-06-20 16:28:15 -07:00
tensorboard
adjust_global_compile_flags.cmake tools: build: fix typo (#21052) 2024-06-19 16:14:58 -07:00
arm64x.cmake Dev/mookerem/arm64x update (#20536) 2024-05-07 12:50:38 -07:00
CMakeLists.txt Make Flash Attention work on Windows (#21015) 2024-06-24 09:43:49 -07:00
CMakeSettings.json
codeconv.runsettings
deps.txt [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
deps_update_and_upload.py Update google benchmark to 1.8.3. (#19734) 2024-03-01 11:01:58 -08:00
EnableVisualStudioCodeAnalysis.props
gdk_toolchain.cmake
Info.plist.in
libonnxruntime.pc.cmake.in
linux_arm32_crosscompile_toolchain.cmake Add a build validation for Linux ARM64 cross-compile (#18200) 2023-11-08 13:03:18 -08:00
linux_arm64_crosscompile_toolchain.cmake Add a build validation for Linux ARM64 cross-compile (#18200) 2023-11-08 13:03:18 -08:00
maccatalyst_prepare_objects_for_prelink.py Support xcframework for mac catalyst builds. (#19534) 2024-03-20 10:55:19 -07:00
nuget_helpers.cmake
onnxruntime.cmake Delete pyop (#21094) 2024-06-19 16:21:33 -07:00
onnxruntime_codegen_tvm.cmake
onnxruntime_common.cmake Enable QNN HTP support for Node (#20576) 2024-05-09 13:11:07 -07:00
onnxruntime_compile_triton_kernel.cmake [CUDA] Add SparseAttention operator for Phi-3-small (#20216) 2024-04-30 09:06:29 -07:00
onnxruntime_config.h.in Enabling c++ 20 in MacOS build (#16187) 2023-09-26 11:27:02 -07:00
onnxruntime_csharp.cmake
onnxruntime_flatbuffers.cmake
onnxruntime_framework.cmake
onnxruntime_framework.natvis
onnxruntime_fuzz_test.cmake
onnxruntime_graph.cmake [Apple framework] Fix minimal build with training enabled. (#19858) 2024-03-12 11:33:30 -07:00
onnxruntime_ios.toolchain.cmake Support visionos build (#20365) 2024-04-23 18:15:07 -07:00
onnxruntime_java.cmake Remove deprecated "mobile" packages (#20941) 2024-06-07 16:20:32 -05:00
onnxruntime_java_unittests.cmake
onnxruntime_kernel_explorer.cmake [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
onnxruntime_mlas.cmake SQNBitGemm - move workspace size calculation functions to hardware-specific implementations (#20757) 2024-05-22 15:12:17 -07:00
onnxruntime_nodejs.cmake Enable QNN HTP support for Node (#20576) 2024-05-09 13:11:07 -07:00
onnxruntime_objectivec.cmake
onnxruntime_opschema_lib.cmake
onnxruntime_optimizer.cmake Flash attention recompute (#20603) 2024-05-21 13:38:19 +08:00
onnxruntime_providers.cmake Add initial support for CoreML ML Program to the CoreML EP. (#19347) 2024-02-15 08:46:03 +10:00
onnxruntime_providers_acl.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_armnn.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_azure.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_cann.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_coreml.cmake Fix Objective-C static analysis warnings. (#20417) 2024-04-24 11:48:29 -07:00
onnxruntime_providers_cpu.cmake Support visionos build (#20365) 2024-04-23 18:15:07 -07:00
onnxruntime_providers_cuda.cmake [CUDA] upgrade cutlass to 3.5.0 (#20940) 2024-06-11 13:32:15 -07:00
onnxruntime_providers_dml.cmake Delay load dxcore.dll in addition to ext-ms-win-dxcore-l1-1-0.dll (#18913) 2023-12-26 12:33:42 -08:00
onnxruntime_providers_dnnl.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_js.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_migraphx.cmake Revert "[MIGraphX EP] enable compilation and execution on Windows (21084)" (#21132) 2024-06-21 01:01:07 -07:00
onnxruntime_providers_nnapi.cmake Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723) 2024-03-12 10:55:49 +10:00
onnxruntime_providers_openvino.cmake Ort openvino npu 1.17 master (#19966) 2024-03-21 18:44:00 -07:00
onnxruntime_providers_qnn.cmake Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723) 2024-03-12 10:55:49 +10:00
onnxruntime_providers_rknpu.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_rocm.cmake [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
onnxruntime_providers_tensorrt.cmake Upgrade GCC and remove the dependency on GCC8's experimental std::filesystem implementation (#20893) 2024-06-03 10:14:08 -07:00
onnxruntime_providers_tvm.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_vitisai.cmake [VitisAI] Solve the problem that gsl cannot be found when compiling under linux (#20466) 2024-04-28 20:56:16 -07:00
onnxruntime_providers_webnn.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_xnnpack.cmake Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723) 2024-03-12 10:55:49 +10:00
onnxruntime_python.cmake Delete pyop (#21094) 2024-06-19 16:21:33 -07:00
onnxruntime_rocm_hipify.cmake [ROCM] Exclude flash attention from hipify (#21091) 2024-06-19 08:59:10 -07:00
onnxruntime_session.cmake
onnxruntime_snpe_provider.cmake
onnxruntime_training.cmake Delete pyop (#21094) 2024-06-19 16:21:33 -07:00
onnxruntime_unittests.cmake Delete pyop (#21094) 2024-06-19 16:21:33 -07:00
onnxruntime_util.cmake
onnxruntime_visionos.toolchain.cmake Support visionos build (#20365) 2024-04-23 18:15:07 -07:00
onnxruntime_webassembly.cmake [js/web] optimize module export and deployment (#20165) 2024-05-20 09:51:16 -07:00
precompiled_header.cmake
riscv64.toolchain.cmake Enable RISC-V 64-bit Cross-Compiling Support for ONNX Runtime on Linux (#19238) 2024-01-24 16:27:05 -08:00
Sdl.ruleset
set_winapi_family_desktop.h
target_delayload.cmake
uwp_stubs.h
wcos_rules_override.cmake Stop using apiset in OneCore build: use onecoreuap.lib instead of onecoreuap_apiset.lib (#19632) 2024-02-23 22:31:57 -08:00
winml.cmake [CP] Fix for xfgcheck and Fix WAI ARM64 build (#19634) (#19644) 2024-03-13 17:54:06 -07:00
winml_cppwinrt.cmake
winml_sdk_helpers.cmake
winml_unittests.cmake Update C/C++ dependencies: abseil, date, nsync, googletest, wil, mp11, cpuinfo and safeint (#15470) 2023-09-08 13:35:04 -07:00