onnxruntime/.github/workflows
aciddelgado ebd0368bb0
Make Flash Attention work on Windows (#21015)
### Description
Previously, Flash Attention only worked on Linux. This PR enables it to
be built and run on Windows as well.

Limitation of Flash Attention on Windows: it requires CUDA 12.

### Motivation and Context
This will significantly increase the performance of Windows-based LLMs
running on GPUs with compute capability sm >= 80 (Ampere or newer).
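As a rough illustration, the two conditions stated above can be encoded as a predicate. This is a hypothetical helper, not ONNX Runtime's actual kernel-selection logic. It assumes that sm >= 80 is a hard requirement on every platform and that CUDA 12 is the only extra condition on Windows. It also ignores any other checks the runtime may perform.

```python
def sm_version(major: int, minor: int) -> int:
    """Convert a CUDA compute capability such as (8, 9) into its 'sm' number, e.g. 89."""
    return major * 10 + minor


def flash_attention_expected(os_name: str, cuda_major: int, sm: int) -> bool:
    """Hypothetical check based only on the limits stated in this PR description."""
    if sm < 80:
        # Assumption: the Flash Attention path is only taken on sm >= 80 GPUs.
        return False
    if os_name.lower() == "windows" and cuda_major < 12:
        # Stated limitation: Flash Attention on Windows requires CUDA 12.
        return False
    return True


# An RTX 4090 (the benchmark GPU below) has compute capability 8.9, i.e. sm 89.
print(flash_attention_expected("Windows", 12, sm_version(8, 9)))  # True
print(flash_attention_expected("Windows", 11, sm_version(8, 9)))  # False
```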

To illustrate the improvement of Flash Attention over Memory Efficient
Attention, here are average benchmark numbers for the GQA
(GroupQueryAttention) operator, run with configurations based on several
recent models (Llama, Mixtral, Phi-3). The benchmarks were obtained on an
RTX 4090 GPU using the test script
`onnxruntime/test/python/transformers/benchmark_gqa_windows.py`.

* Note: These benchmarks cover only the GQA operator, not the entire
model.

### Memory Efficient Attention Kernel Benchmarks:
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|------------|---------------------|-------------------------|-----------------------------|
| Llama3-8B (Average Prompt) | 8192 | 0.19790525 | 13105.63425 |
| Llama3-8B (Average Token) | 8192 | 0.207775538 | 12025.10172 |
| Llama3-70B (Average Prompt) | 8192 | 0.216049167 | 11563.31185 |
| Llama3-70B (Average Token) | 8192 | 0.209730731 | 12284.38149 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.371928785 | 7031.440056 |
| Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.2996659 | 7607.947159 |
| Phi-3-mini-128k (Average Prompt) | 131072 | 0.183195867 | 15542.0852 |
| Phi-3-mini-128k (Average Token) | 131072 | 0.198215688 | 12874.53494 |
| Phi-3-small-128k (Average Prompt) | 65536 | 2.9884929 | 2332.584142 |
| Phi-3-small-128k (Average Token) | 65536 | 0.845072406 | 2877.85822 |
| Phi-3-medium-128K (Average Prompt) | 32768 | 0.324974429 | 8094.909517 |
| Phi-3-medium-128K (Average Token) | 32768 | 0.263662567 | 8978.463687 |

### Flash Attention Kernel Benchmarks:
| Model Name | Max Sequence Length | Inference Interval (ms) | Throughput (samples/second) |
|------------|---------------------|-------------------------|-----------------------------|
| Llama3-8B (Average Prompt) | 8192 | 0.163566292 | 16213.69057 |
| Llama3-8B (Average Token) | 8192 | 0.161643692 | 16196.14715 |
| Llama3-70B (Average Prompt) | 8192 | 0.160510375 | 17448.67753 |
| Llama3-70B (Average Token) | 8192 | 0.169427308 | 14702.62043 |
| Mixtral-8x22B-v0.1 (Average Prompt) | 32768 | 0.164121964 | 15618.51301 |
| Mixtral-8x22B-v0.1 (Average Token) | 32768 | 0.1715865 | 14524.32273 |
| Phi-3-mini-128k (Average Prompt) | 131072 | 0.167527167 | 14576.725 |
| Phi-3-mini-128k (Average Token) | 131072 | 0.175940594 | 15762.051 |
| Phi-3-small-128k (Average Prompt) | 65536 | 0.162719733 | 17824.494 |
| Phi-3-small-128k (Average Token) | 65536 | 0.14977525 | 16749.19858 |
| Phi-3-medium-128K (Average Prompt) | 32768 | 0.156490786 | 17679.2513 |
| Phi-3-medium-128K (Average Token) | 32768 | 0.165333833 | 14932.26079 |

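Flash Attention had a lower average inference interval than Memory
Efficient Attention in every configuration we benchmarked. Throughput
improved in all but one configuration, by about 20% (Llama3-70B, Average
Token) up to about 660% (Phi-3-small-128k, Average Prompt). For
Phi-3-mini-128k (Average Prompt), average throughput was about 6% lower
despite the shorter interval.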
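The per-configuration comparison above can be recomputed directly from the two tables. The sketch below is illustrative and is not part of `benchmark_gqa_windows.py`. It takes each row's averaged figures as given and compares throughput and interval independently. Because each column is an average over runs, throughput need not equal the reciprocal of the interval.

```python
# Figures copied from the two benchmark tables:
# configuration -> (inference interval in ms, throughput in samples/second)
MEA = {
    "Llama3-8B (Average Prompt)": (0.19790525, 13105.63425),
    "Llama3-8B (Average Token)": (0.207775538, 12025.10172),
    "Llama3-70B (Average Prompt)": (0.216049167, 11563.31185),
    "Llama3-70B (Average Token)": (0.209730731, 12284.38149),
    "Mixtral-8x22B-v0.1 (Average Prompt)": (0.371928785, 7031.440056),
    "Mixtral-8x22B-v0.1 (Average Token)": (0.2996659, 7607.947159),
    "Phi-3-mini-128k (Average Prompt)": (0.183195867, 15542.0852),
    "Phi-3-mini-128k (Average Token)": (0.198215688, 12874.53494),
    "Phi-3-small-128k (Average Prompt)": (2.9884929, 2332.584142),
    "Phi-3-small-128k (Average Token)": (0.845072406, 2877.85822),
    "Phi-3-medium-128K (Average Prompt)": (0.324974429, 8094.909517),
    "Phi-3-medium-128K (Average Token)": (0.263662567, 8978.463687),
}
FLASH = {
    "Llama3-8B (Average Prompt)": (0.163566292, 16213.69057),
    "Llama3-8B (Average Token)": (0.161643692, 16196.14715),
    "Llama3-70B (Average Prompt)": (0.160510375, 17448.67753),
    "Llama3-70B (Average Token)": (0.169427308, 14702.62043),
    "Mixtral-8x22B-v0.1 (Average Prompt)": (0.164121964, 15618.51301),
    "Mixtral-8x22B-v0.1 (Average Token)": (0.1715865, 14524.32273),
    "Phi-3-mini-128k (Average Prompt)": (0.167527167, 14576.725),
    "Phi-3-mini-128k (Average Token)": (0.175940594, 15762.051),
    "Phi-3-small-128k (Average Prompt)": (0.162719733, 17824.494),
    "Phi-3-small-128k (Average Token)": (0.14977525, 16749.19858),
    "Phi-3-medium-128K (Average Prompt)": (0.156490786, 17679.2513),
    "Phi-3-medium-128K (Average Token)": (0.165333833, 14932.26079),
}


def percent_change(new: float, old: float) -> float:
    """Relative change from old to new, in percent."""
    return (new / old - 1.0) * 100.0


def compare(baseline, candidate):
    """Yield (configuration, throughput change in %, interval speedup factor)."""
    for name, (old_ms, old_tput) in baseline.items():
        new_ms, new_tput = candidate[name]
        # A speedup > 1 means the candidate's average interval is shorter.
        yield name, percent_change(new_tput, old_tput), old_ms / new_ms


if __name__ == "__main__":
    for name, tput_pct, speedup in compare(MEA, FLASH):
        print(f"{name:38s} throughput {tput_pct:+7.1f}%  interval speedup {speedup:5.2f}x")
```

Running it shows Phi-3-mini-128k (Average Prompt) as the single configuration with a negative throughput change, even though its interval speedup is still above 1.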

In addition to these performance gains, Flash Attention uses memory more
efficiently. For example, Memory Efficient Attention cannot handle a max
sequence length higher than 32,768, while Flash Attention can handle max
sequence lengths at least as high as 131,072.

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
2024-06-24 09:43:49 -07:00
| File | Last commit | Date |
|------|-------------|------|
| cffconvert.yml | Bump actions/checkout from 3 to 4 (#17487) | 2023-09-13 09:22:21 -07:00 |
| codeql.yml | Use Java 11 to build project in the codeql pipeline (#19999) | 2024-03-20 17:53:48 -07:00 |
| generate-skip-doc-change.py | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| gradle-wrapper-validation.yml | Bump gradle/wrapper-validation-action from 2 to 3 (#20305) | 2024-04-16 14:20:51 -07:00 |
| labeler.yml | Update labeler.yml to change permissions (#19709) | 2024-02-28 21:10:25 -08:00 |
| lint.yml | Make Flash Attention work on Windows (#21015) | 2024-06-24 09:43:49 -07:00 |
| mac.yml | Add Mac CI GitHub Actions workflow (#20717) | 2024-05-20 10:27:03 -07:00 |
| publish-c-apidocs.yml | Bump actions/upload-artifact from 3 to 4 (#18920) | 2023-12-31 21:10:47 -08:00 |
| publish-csharp-apidocs.yml | Bump nuget/setup-nuget from 1 to 2 (#19411) | 2024-02-13 15:59:15 -08:00 |
| publish-gh-pages.yml | Add website publish placeholder (#17318) | 2023-08-30 11:01:54 -07:00 |
| publish-java-apidocs.yml | Bump gradle/gradle-build-action from 2 to 3 (#19297) | 2024-02-05 09:41:57 -08:00 |
| publish-js-apidocs.yml | Bump actions/upload-artifact from 3 to 4 (#18920) | 2023-12-31 21:10:47 -08:00 |
| publish-objectivec-apidocs.yml | Fix training and macos ci pipelines (#20034) | 2024-03-26 12:20:11 -07:00 |
| publish-python-apidocs.yml | Bump actions/upload-artifact from 3 to 4 (#18920) | 2023-12-31 21:10:47 -08:00 |
| sca.yml | Fix a perm issue in Windows Static Analysis pipeline (#21100) | 2024-06-19 14:44:39 -07:00 |
| skip-doc-change.yml.j2 | Update Win_GPU_CI trigger (#13290) | 2022-10-12 15:22:42 +08:00 |
| stale.yml | Update stale.yml to use old version as a bug fix (#19532) | 2024-02-15 17:03:11 -08:00 |
| windows.yml | Remove TVM EP's pipeline (#20813) | 2024-05-25 20:42:41 -07:00 |