onnxruntime/tools/ci_build/github/azure-pipelines
aciddelgado ebd0368bb0
Make Flash Attention work on Windows (#21015)
### Description
Previously, Flash Attention only worked on Linux. This PR enables it to be
built and run on Windows.

Limitation of Flash Attention on Windows: it requires CUDA 12.

### Motivation and Context
This will significantly improve the performance of LLMs on Windows
running on GPUs with compute capability sm >= 80.
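As a quick illustration of the two requirements stated above (CUDA 12 and sm >= 80), here is a minimal sketch of an eligibility check. The helper `flash_attention_eligible` is hypothetical and not part of ONNX Runtime. It assumes the compute capability arrives as a `(major, minor)` tuple, such as the one PyTorch's `torch.cuda.get_device_capability()` returns. It also assumes "requires CUDA 12" means CUDA 12 or newer.

```python
from typing import Tuple


def flash_attention_eligible(compute_capability: Tuple[int, int], cuda_major: int) -> bool:
    """Hypothetical helper: checks the two Windows requirements from this PR.

    compute_capability: (major, minor), e.g. (8, 9) for an RTX 4090.
    cuda_major: major version of the CUDA toolkit (assumed: 12 or newer is OK).
    """
    major, minor = compute_capability
    sm = major * 10 + minor  # (8, 9) -> 89, i.e. sm_89
    return sm >= 80 and cuda_major >= 12


if __name__ == "__main__":
    # RTX 4090 (sm_89) with CUDA 12, and T4 (sm_75) with CUDA 12
    print(flash_attention_eligible((8, 9), 12))
    print(flash_attention_eligible((7, 5), 12))
```

A T4 (sm_75) fails the compute-capability check even with CUDA 12. An A100 (sm_80) on CUDA 11 fails the CUDA check.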

To illustrate the improvement of Flash Attention over Memory Efficient
Attention, here are average benchmark numbers for the GQA operator, run
with configurations based on several recent models (Llama, Mixtral,
Phi-3). The benchmarks were obtained on an RTX 4090 GPU using the test
script `onnxruntime/test/python/transformers/benchmark_gqa_windows.py`.

* Note: These benchmarks cover only the GQA operator, not the entire
model.

### Memory Efficient Attention Kernel Benchmarks
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|----------------------------------------|---------------------|-------------------------|-----------------------------|
| Llama3-8B (Average Prompt) | 8192 | 0.19790525 | 13105.63425 |
| Llama3-8B (Average Token) | 8192 | 0.207775538 | 12025.10172 |
| Llama3-70B (Average Prompt) | 8192 | 0.216049167 | 11563.31185 |
| Llama3-70B (Average Token) | 8192 | 0.209730731 | 12284.38149 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.371928785 | 7031.440056 |
| Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.2996659 | 7607.947159 |
| Phi-3-mini-128k (Average Prompt) | 131072 | 0.183195867 | 15542.0852 |
| Phi-3-mini-128k (Average Token) | 131072 | 0.198215688 | 12874.53494 |
| Phi-3-small-128k (Average Prompt) | 65536 | 2.9884929 | 2332.584142 |
| Phi-3-small-128k (Average Token) | 65536 | 0.845072406 | 2877.85822 |
| Phi-3-medium-128K (Average Prompt) | 32768 | 0.324974429 | 8094.909517 |
| Phi-3-medium-128K (Average Token) | 32768 | 0.263662567 | 8978.463687 |

### Flash Attention Kernel Benchmarks
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|--------------------------------------|---------------------|-------------------------|-----------------------------|
| Llama3-8B (Average Prompt) | 8192 | 0.163566292 | 16213.69057 |
| Llama3-8B (Average Token) | 8192 | 0.161643692 | 16196.14715 |
| Llama3-70B (Average Prompt) | 8192 | 0.160510375 | 17448.67753 |
| Llama3-70B (Average Token) | 8192 | 0.169427308 | 14702.62043 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.164121964 | 15618.51301 |
| Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.1715865 | 14524.32273 |
| Phi-3-mini-128k (Average Prompt) | 131072 | 0.167527167 | 14576.725 |
| Phi-3-mini-128k (Average Token) | 131072 | 0.175940594 | 15762.051 |
| Phi-3-small-128k (Average Prompt) | 65536 | 0.162719733 | 17824.494 |
| Phi-3-small-128k (Average Token) | 65536 | 0.14977525 | 16749.19858 |
| Phi-3-medium-128K (Average Prompt) | 32768 | 0.156490786 | 17679.2513 |
| Phi-3-medium-128K (Average Token) | 32768 | 0.165333833 | 14932.26079 |

Flash Attention has a lower inference interval than Memory Efficient
Attention in every configuration we benchmarked. In our trials it
improved throughput by over 650% in the best case (Phi-3-small-128k,
Average Prompt).
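The per-configuration comparison can be recomputed directly from the two tables. The sketch below is not part of the PR or of `benchmark_gqa_windows.py`. It copies the published numbers and derives two values per row. The latency speedup is the Memory Efficient interval divided by the Flash interval. The throughput change is the Flash throughput relative to Memory Efficient, as a percentage. The names `RESULTS` and `compare` are illustrative only.

```python
# (MEA interval ms, MEA samples/s, Flash interval ms, Flash samples/s),
# copied from the two benchmark tables above.
RESULTS = {
    "Llama3-8B (Average Prompt)": (0.19790525, 13105.63425, 0.163566292, 16213.69057),
    "Llama3-8B (Average Token)": (0.207775538, 12025.10172, 0.161643692, 16196.14715),
    "Llama3-70B (Average Prompt)": (0.216049167, 11563.31185, 0.160510375, 17448.67753),
    "Llama3-70B (Average Token)": (0.209730731, 12284.38149, 0.169427308, 14702.62043),
    "Mixtral-8x22B-v0.1 (Average Prompt)": (0.371928785, 7031.440056, 0.164121964, 15618.51301),
    "Mixtral-8x22B-v0.1 (Average Token)": (0.2996659, 7607.947159, 0.1715865, 14524.32273),
    "Phi-3-mini-128k (Average Prompt)": (0.183195867, 15542.0852, 0.167527167, 14576.725),
    "Phi-3-mini-128k (Average Token)": (0.198215688, 12874.53494, 0.175940594, 15762.051),
    "Phi-3-small-128k (Average Prompt)": (2.9884929, 2332.584142, 0.162719733, 17824.494),
    "Phi-3-small-128k (Average Token)": (0.845072406, 2877.85822, 0.14977525, 16749.19858),
    "Phi-3-medium-128K (Average Prompt)": (0.324974429, 8094.909517, 0.156490786, 17679.2513),
    "Phi-3-medium-128K (Average Token)": (0.263662567, 8978.463687, 0.165333833, 14932.26079),
}


def compare(results):
    """Return (name, latency speedup, throughput change in %) for each row."""
    rows = []
    for name, (mea_ms, mea_tput, fa_ms, fa_tput) in results.items():
        latency_speedup = mea_ms / fa_ms  # > 1.0 means Flash has the shorter interval
        tput_change_pct = (fa_tput / mea_tput - 1.0) * 100.0
        rows.append((name, latency_speedup, tput_change_pct))
    return rows


if __name__ == "__main__":
    for name, speedup, change in compare(RESULTS):
        print(f"{name:38s} latency x{speedup:5.2f}  throughput {change:+7.1f}%")
```

By this arithmetic, the latency speedup is above 1.0 in every row. The throughput change ranges from about -6% (Phi-3-mini-128k, Average Prompt) to about +664% (Phi-3-small-128k, Average Prompt).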

In addition to these improvements in performance, Flash Attention has
better memory usage. For example, Memory Efficient Attention cannot
handle a max sequence length higher than 32,768, but Flash Attention can
handle max sequence lengths at least as high as 131,072.

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
2024-06-24 09:43:49 -07:00
nodejs/templates Adding Job names to jobs without a name (#20961) 2024-06-06 19:09:21 -07:00
nuget/templates Updating cudnn from 8 to 9 on exsiting cuda 12 docker image (#20925) 2024-06-11 09:37:16 -07:00
stages Add UsePythonVersion (#21109) 2024-06-19 20:47:21 -07:00
templates [Fix] use cmdline in Final Jar Testing Stage for new managed Windows Image (#21130) 2024-06-21 12:41:06 +08:00
triggers
android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
android-x86_64-crosscompile-ci-pipeline.yml Clarify when protobuf dependency builds protoc (#20542) 2024-05-08 08:30:11 +10:00
bigmodels-ci-pipeline.yml Use A100 for LLama2 model test (#21068) 2024-06-18 11:04:02 +08:00
binary-size-checks-pipeline.yml
build-perf-test-binaries-pipeline.yml
c-api-noopenmp-packaging-pipelines.yml Update generate_nuspec_for_native_nuget.py for training (#21112) 2024-06-20 16:13:31 -07:00
clean-build-docker-image-cache-pipeline.yml
cuda-packaging-pipeline.yml Update training packaging pipeline's docker files (#20853) 2024-05-30 23:48:42 -07:00
linux-ci-pipeline.yml Update training packaging pipeline's docker files (#20853) 2024-05-30 23:48:42 -07:00
linux-cpu-aten-pipeline.yml Update Aten pipeline's docker file to use UBI8 (#20856) 2024-05-30 07:38:15 -07:00
linux-cpu-eager-pipeline.yml Update Aten pipeline's docker file to use UBI8 (#20856) 2024-05-30 07:38:15 -07:00
linux-cpu-minimal-build-ci-pipeline.yml Update training packaging pipeline's docker files (#20853) 2024-05-30 23:48:42 -07:00
linux-dnnl-ci-pipeline.yml Update training packaging pipeline's docker files (#20853) 2024-05-30 23:48:42 -07:00
linux-gpu-ci-pipeline.yml Updating cudnn from 8 to 9 on exsiting cuda 12 docker image (#20925) 2024-06-11 09:37:16 -07:00
linux-gpu-tensorrt-ci-pipeline.yml Updating cudnn from 8 to 9 on exsiting cuda 12 docker image (#20925) 2024-06-11 09:37:16 -07:00
linux-gpu-tensorrt-daily-perf-pipeline.yml [EP Perf] Fix on EP Perf (#20683) 2024-05-15 21:38:52 -07:00
linux-migraphx-ci-pipeline.yml [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
linux-openvino-ci-pipeline.yml
linux-qnn-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
mac-ci-pipeline.yml Delete pyop (#21094) 2024-06-19 16:21:33 -07:00
mac-coreml-ci-pipeline.yml
mac-ios-ci-pipeline.yml Upgrade min ios version to 13.0 (#20773) 2024-06-04 10:15:20 -07:00
mac-ios-packaging-pipeline.yml Upgrade min ios version to 13.0 (#20773) 2024-06-04 10:15:20 -07:00
mac-react-native-ci-pipeline.yml Address React Native pipeline component detection timeout (#20871) 2024-05-30 16:37:03 -07:00
npm-packaging-pipeline.yml Increase NPM ComponentDetection.Timeout: 1200 (#20681) 2024-05-15 13:41:59 -07:00
nuget-cuda-publishing-pipeline.yml adding publishing stage to publish java CUDA 12 pkg to ado (#20834) 2024-05-29 16:24:23 -07:00
orttraining-linux-ci-pipeline.yml Remove manylinux build scripts from python packaging pipeline (#20786) 2024-05-24 08:18:22 -07:00
orttraining-linux-gpu-ci-pipeline.yml
orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml custom allreduce cuda kernel (#20703) 2024-06-13 11:09:49 -07:00
orttraining-linux-nightly-ortmodule-test-pipeline.yml
orttraining-mac-ci-pipeline.yml
orttraining-pai-ci-pipeline.yml [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
orttraining-py-packaging-pipeline-cpu.yml Update training packaging pipeline's docker files (#20853) 2024-05-30 23:48:42 -07:00
orttraining-py-packaging-pipeline-cuda.yml Update training packaging pipeline's docker files (#20853) 2024-05-30 23:48:42 -07:00
orttraining-py-packaging-pipeline-cuda12.yml Update training packaging pipeline's docker files (#20853) 2024-05-30 23:48:42 -07:00
orttraining-py-packaging-pipeline-rocm.yml [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
post-merge-jobs.yml Remove deprecated "mobile" packages (#20941) 2024-06-07 16:20:32 -05:00
publish-nuget.yml Add UsePython Task in Nuget Publish workflow (#21144) 2024-06-24 13:36:13 +08:00
py-cuda-package-test-pipeline.yml
py-cuda-packaging-pipeline.yml Remove manylinux build scripts from python packaging pipeline (#20786) 2024-05-24 08:18:22 -07:00
py-cuda-publishing-pipeline.yml Update py-publishing pipeline to use the resoure from packaging pipeline (#20888) 2024-06-01 16:10:02 -07:00
py-package-build-pipeline.yml
py-package-test-pipeline.yml Upgrade GCC and remove the dependency on GCC8's experimental std::filesystem implementation (#20893) 2024-06-03 10:14:08 -07:00
py-packaging-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
qnn-ep-nuget-packaging-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
web-ci-pipeline.yml
win-ci-fuzz-testing.yml Run fuzz testing before the CG task cleans up the build directory (#20500) 2024-04-29 16:02:53 +10:00
win-ci-pipeline.yml Adding Job names to jobs without a name (#20961) 2024-06-06 19:09:21 -07:00
win-gpu-ci-pipeline.yml Make Flash Attention work on Windows (#21015) 2024-06-24 09:43:49 -07:00
win-gpu-reduce-op-ci-pipeline.yml Move jobs in onnxruntime-Win2022-GPU-T4 machine pool to onnxruntime-Win2022-GPU-A10 (#21023) 2024-06-12 22:04:40 -07:00
win-gpu-tensorrt-ci-pipeline.yml Move jobs in onnxruntime-Win2022-GPU-T4 machine pool to onnxruntime-Win2022-GPU-A10 (#21023) 2024-06-12 22:04:40 -07:00
win-qnn-arm64-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
win-qnn-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00