onnxruntime/tools/ci_build/github/azure-pipelines
aciddelgado ebd0368bb0
Make Flash Attention work on Windows (#21015)
### Description
Previously, Flash Attention only worked on Linux systems. This PR enables
it to be built and run on Windows.

Limitation of Flash Attention on Windows: it requires CUDA 12.

### Motivation and Context
This will significantly improve the performance of LLMs on Windows
running on GPUs with compute capability 8.0 or higher (sm >= 80).
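The two requirements stated in this PR (a GPU with compute capability 8.0 or higher, and CUDA 12 when running on Windows) can be summarized as a small predicate. This is a hypothetical Python helper for illustration only. `GpuEnv` and `meets_stated_flash_attention_requirements` are not ONNX Runtime APIs, and ONNX Runtime's actual kernel selection may check additional conditions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GpuEnv:
    """Hypothetical description of a target machine (not an ONNX Runtime type)."""
    cc_major: int    # CUDA compute capability, major part (e.g. 8 for an RTX 4090)
    cc_minor: int    # CUDA compute capability, minor part (e.g. 9 for an RTX 4090)
    cuda_major: int  # major version of the CUDA toolkit used by the build
    is_windows: bool


def meets_stated_flash_attention_requirements(env: GpuEnv) -> bool:
    """Check only the two conditions stated in this PR.

    Hypothetical helper: sm >= 80 everywhere, plus CUDA 12 on Windows.
    ONNX Runtime's real dispatch logic may also check other properties
    (for example data type or head size), which are not modeled here.
    """
    sm = env.cc_major * 10 + env.cc_minor  # compute capability 8.9 -> sm89
    if sm < 80:
        return False
    if env.is_windows and env.cuda_major < 12:
        return False
    return True
```

For example, an RTX 4090 (compute capability 8.9) with a CUDA 12 build on Windows passes this check, while a T4 (compute capability 7.5) does not.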

To illustrate the improvement of Flash Attention over Memory Efficient
Attention, here are average benchmark numbers for the GQA
(GroupQueryAttention) operator, run with configurations based on several
recent models (Llama, Mixtral, Phi-3). The benchmarks were obtained on an
RTX 4090 GPU using the test script
`onnxruntime/test/python/transformers/benchmark_gqa_windows.py`.

* Note: these benchmarks measure only the GQA operator, not the entire
model.

### Memory Efficient Attention Kernel Benchmarks:
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|----------------------------------------|---------------------|-------------------------|-----------------------------|
| Llama3-8B (Average Prompt) | 8192 | 0.19790525 | 13105.63425 |
| Llama3-8B (Average Token) | 8192 | 0.207775538 | 12025.10172 |
| Llama3-70B (Average Prompt) | 8192 | 0.216049167 | 11563.31185 |
| Llama3-70B (Average Token) | 8192 | 0.209730731 | 12284.38149 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.371928785 | 7031.440056 |
| Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.2996659 | 7607.947159 |
| Phi-3-mini-128k (Average Prompt) | 131072 | 0.183195867 | 15542.0852 |
| Phi-3-mini-128k (Average Token) | 131072 | 0.198215688 | 12874.53494 |
| Phi-3-small-128k (Average Prompt) | 65536 | 2.9884929 | 2332.584142 |
| Phi-3-small-128k (Average Token) | 65536 | 0.845072406 | 2877.85822 |
| Phi-3-medium-128K (Average Prompt) | 32768 | 0.324974429 | 8094.909517 |
| Phi-3-medium-128K (Average Token) | 32768 | 0.263662567 | 8978.463687 |

### Flash Attention Kernel Benchmarks:
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|--------------------------------------|---------------------|-------------------------|-----------------------------|
| Llama3-8B (Average Prompt) | 8192 | 0.163566292 | 16213.69057 |
| Llama3-8B (Average Token) | 8192 | 0.161643692 | 16196.14715 |
| Llama3-70B (Average Prompt) | 8192 | 0.160510375 | 17448.67753 |
| Llama3-70B (Average Token) | 8192 | 0.169427308 | 14702.62043 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.164121964 | 15618.51301 |
| Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.1715865 | 14524.32273 |
| Phi-3-mini-128k (Average Prompt) | 131072 | 0.167527167 | 14576.725 |
| Phi-3-mini-128k (Average Token) | 131072 | 0.175940594 | 15762.051 |
| Phi-3-small-128k (Average Prompt) | 65536 | 0.162719733 | 17824.494 |
| Phi-3-small-128k (Average Token) | 65536 | 0.14977525 | 16749.19858 |
| Phi-3-medium-128K (Average Prompt) | 32768 | 0.156490786 | 17679.2513 |
| Phi-3-medium-128K (Average Token) | 32768 | 0.165333833 | 14932.26079 |

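Flash Attention had a lower average inference interval than Memory
Efficient Attention in every configuration we benchmarked. Throughput
improved by ~20% to ~650% in all but one configuration; for
Phi-3-mini-128k (Average Prompt), throughput was about 6% lower.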
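The per-configuration differences can be re-derived from the two tables. The following Python sketch copies the published averages into two dictionaries and, for each configuration, computes the interval ratio (Memory Efficient / Flash) and the relative throughput change. The names `MEA`, `FA`, and `compare` are illustrative and not part of `benchmark_gqa_windows.py`. This sketch only does arithmetic on the averages above and does not re-run any benchmark.

```python
"""Recompute per-configuration improvements from the GQA benchmark tables."""

# (inference interval in ms, throughput in samples/second), copied from the
# "Memory Efficient Attention Kernel Benchmarks" table.
MEA = {
    "Llama3-8B (Average Prompt)": (0.19790525, 13105.63425),
    "Llama3-8B (Average Token)": (0.207775538, 12025.10172),
    "Llama3-70B (Average Prompt)": (0.216049167, 11563.31185),
    "Llama3-70B (Average Token)": (0.209730731, 12284.38149),
    "Mixtral-8x22B-v0.1 (Average Prompt)": (0.371928785, 7031.440056),
    "Mixtral-8x22B-v0.1 (Average Token)": (0.2996659, 7607.947159),
    "Phi-3-mini-128k (Average Prompt)": (0.183195867, 15542.0852),
    "Phi-3-mini-128k (Average Token)": (0.198215688, 12874.53494),
    "Phi-3-small-128k (Average Prompt)": (2.9884929, 2332.584142),
    "Phi-3-small-128k (Average Token)": (0.845072406, 2877.85822),
    "Phi-3-medium-128K (Average Prompt)": (0.324974429, 8094.909517),
    "Phi-3-medium-128K (Average Token)": (0.263662567, 8978.463687),
}

# Same layout, copied from the "Flash Attention Kernel Benchmarks" table.
FA = {
    "Llama3-8B (Average Prompt)": (0.163566292, 16213.69057),
    "Llama3-8B (Average Token)": (0.161643692, 16196.14715),
    "Llama3-70B (Average Prompt)": (0.160510375, 17448.67753),
    "Llama3-70B (Average Token)": (0.169427308, 14702.62043),
    "Mixtral-8x22B-v0.1 (Average Prompt)": (0.164121964, 15618.51301),
    "Mixtral-8x22B-v0.1 (Average Token)": (0.1715865, 14524.32273),
    "Phi-3-mini-128k (Average Prompt)": (0.167527167, 14576.725),
    "Phi-3-mini-128k (Average Token)": (0.175940594, 15762.051),
    "Phi-3-small-128k (Average Prompt)": (0.162719733, 17824.494),
    "Phi-3-small-128k (Average Token)": (0.14977525, 16749.19858),
    "Phi-3-medium-128K (Average Prompt)": (0.156490786, 17679.2513),
    "Phi-3-medium-128K (Average Token)": (0.165333833, 14932.26079),
}


def compare(mea, fa):
    """Return (name, interval speedup, throughput change in %) per configuration."""
    rows = []
    for name, (mea_ms, mea_tput) in mea.items():
        fa_ms, fa_tput = fa[name]
        # > 1 means Flash Attention has the shorter inference interval.
        interval_speedup = mea_ms / fa_ms
        # > 0 means Flash Attention has the higher throughput.
        tput_change_pct = (fa_tput / mea_tput - 1.0) * 100.0
        rows.append((name, interval_speedup, tput_change_pct))
    return rows


if __name__ == "__main__":
    for name, speedup, change in compare(MEA, FA):
        print(f"{name:38s} interval x{speedup:5.2f}  throughput {change:+7.1f}%")
```

Running it shows a shorter interval with Flash Attention in every row, while the throughput change ranges from a regression of roughly 6% for Phi-3-mini-128k (Average Prompt) to a gain of roughly 660% for Phi-3-small-128k (Average Prompt).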

In addition to these performance improvements, Flash Attention uses
memory more efficiently. For example, Memory Efficient Attention cannot
handle a maximum sequence length above 32,768, but Flash Attention can
handle maximum sequence lengths of at least 131,072.

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
2024-06-24 09:43:49 -07:00
nodejs/templates
nuget/templates Updating cudnn from 8 to 9 on exsiting cuda 12 docker image (#20925) 2024-06-11 09:37:16 -07:00
stages Add UsePythonVersion (#21109) 2024-06-19 20:47:21 -07:00
templates [Fix] use cmdline in Final Jar Testing Stage for new managed Windows Image (#21130) 2024-06-21 12:41:06 +08:00
triggers
android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
android-x86_64-crosscompile-ci-pipeline.yml
bigmodels-ci-pipeline.yml Use A100 for LLama2 model test (#21068) 2024-06-18 11:04:02 +08:00
binary-size-checks-pipeline.yml
build-perf-test-binaries-pipeline.yml
c-api-noopenmp-packaging-pipelines.yml Update generate_nuspec_for_native_nuget.py for training (#21112) 2024-06-20 16:13:31 -07:00
clean-build-docker-image-cache-pipeline.yml
cuda-packaging-pipeline.yml
linux-ci-pipeline.yml
linux-cpu-aten-pipeline.yml
linux-cpu-eager-pipeline.yml
linux-cpu-minimal-build-ci-pipeline.yml
linux-dnnl-ci-pipeline.yml
linux-gpu-ci-pipeline.yml Updating cudnn from 8 to 9 on exsiting cuda 12 docker image (#20925) 2024-06-11 09:37:16 -07:00
linux-gpu-tensorrt-ci-pipeline.yml Updating cudnn from 8 to 9 on exsiting cuda 12 docker image (#20925) 2024-06-11 09:37:16 -07:00
linux-gpu-tensorrt-daily-perf-pipeline.yml
linux-migraphx-ci-pipeline.yml [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
linux-openvino-ci-pipeline.yml
linux-qnn-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
mac-ci-pipeline.yml Delete pyop (#21094) 2024-06-19 16:21:33 -07:00
mac-coreml-ci-pipeline.yml
mac-ios-ci-pipeline.yml
mac-ios-packaging-pipeline.yml
mac-react-native-ci-pipeline.yml
npm-packaging-pipeline.yml
nuget-cuda-publishing-pipeline.yml
orttraining-linux-ci-pipeline.yml
orttraining-linux-gpu-ci-pipeline.yml
orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml custom allreduce cuda kernel (#20703) 2024-06-13 11:09:49 -07:00
orttraining-linux-nightly-ortmodule-test-pipeline.yml
orttraining-mac-ci-pipeline.yml
orttraining-pai-ci-pipeline.yml [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
orttraining-py-packaging-pipeline-cpu.yml
orttraining-py-packaging-pipeline-cuda.yml
orttraining-py-packaging-pipeline-cuda12.yml
orttraining-py-packaging-pipeline-rocm.yml [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
post-merge-jobs.yml
publish-nuget.yml Add UsePython Task in Nuget Publish workflow (#21144) 2024-06-24 13:36:13 +08:00
py-cuda-package-test-pipeline.yml
py-cuda-packaging-pipeline.yml
py-cuda-publishing-pipeline.yml
py-package-build-pipeline.yml
py-package-test-pipeline.yml
py-packaging-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
qnn-ep-nuget-packaging-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
web-ci-pipeline.yml
win-ci-fuzz-testing.yml
win-ci-pipeline.yml
win-gpu-ci-pipeline.yml Make Flash Attention work on Windows (#21015) 2024-06-24 09:43:49 -07:00
win-gpu-reduce-op-ci-pipeline.yml Move jobs in onnxruntime-Win2022-GPU-T4 machine pool to onnxruntime-Win2022-GPU-A10 (#21023) 2024-06-12 22:04:40 -07:00
win-gpu-tensorrt-ci-pipeline.yml Move jobs in onnxruntime-Win2022-GPU-T4 machine pool to onnxruntime-Win2022-GPU-A10 (#21023) 2024-06-12 22:04:40 -07:00
win-qnn-arm64-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00
win-qnn-ci-pipeline.yml [QNN EP] Update QNN SDK to 2.23.0 (#21008) 2024-06-19 12:37:42 -07:00