Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention v2 into the PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
The Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, then removed the
backward and bfloat16 related code.
Add a benchmark script (see benchmark_mha.sh) to compare different
attention kernels for the MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works on Linux. For Windows, we will
enable it later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enabled). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
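As a rough illustration of how the two variables interact for the packed QKV format, the sketch below models the behavior exactly as described above. The function name `flash_allowed_for_packed_qkv` is hypothetical, and this is not ORT's actual selection code, which also has to honor the other limitations listed earlier.

```python
import os


def flash_allowed_for_packed_qkv(sequence_length, env=None):
    """Model the two switches as the description states them (hypothetical helper)."""
    env = os.environ if env is None else env
    # "1" disables flash attention entirely; the default "0" leaves it enabled.
    if env.get("ORT_DISABLE_FLASH_ATTENTION", "0") == "1":
        return False
    # Default "513": flash attention only when sequence length is larger than 512.
    min_seq_len = int(env.get("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", "513"))
    return sequence_length >= min_seq_len
```

In a test script the variables would be set through `os.environ` (for example `os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"`), presumably before the inference session is created.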
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
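The three layouts can be summarized as tensor shapes. The helper below is only a sketch: `input_shapes` is a hypothetical name, it assumes query and key/value share the same sequence length, and it assumes the query stays 3D (BxSxNH) in the `Q,KV` case, which the list above does not state explicitly.

```python
def input_shapes(fmt, batch, seq_len, num_heads, head_size):
    """Return the input tensor shapes for one of the three formats (illustrative)."""
    b, s, n, h = batch, seq_len, num_heads, head_size
    if fmt == "Q,K,V":
        # Separate query, key and value, each BxSxNH (hidden size = N * H).
        shape = (b, s, n * h)
        return {"query": shape, "key": shape, "value": shape}
    if fmt == "Q,KV":
        # Packed KV: key is 5D, BxSxNx2xH. Query shape is an assumption here.
        return {"query": (b, s, n * h), "key": (b, s, n, 2, h)}
    if fmt == "QKV":
        # Packed QKV: query is 5D, BxSxNx3xH.
        return {"query": (b, s, n, 3, h)}
    raise ValueError(f"unknown format: {fmt}")
```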
Note that flash attention cannot consume the packed QKV format directly,
so an extra Transpose is needed. We found that the TensorRT kernel is
faster for sequence length <= 512 with packed QKV, possibly because the
TensorRT kernel needs no transpose in this format.
We also noticed that the TensorRT kernel is faster for Stable Diffusion
512x512 images (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 images (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
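
To check observations such as the Stable Diffusion one against the rows above, a row can be parsed and the fastest kernel picked. This is a minimal sketch: `fastest_kernel` is a hypothetical helper, and it assumes the column order of the table header (three TFLOPs/s columns after the five shape columns).

```python
KERNELS = ("flash_v2", "TensorRT", "Memory Efficient Attention")


def fastest_kernel(row):
    """Return (kernel name, TFLOPs/s) for the fastest kernel in one table row."""
    # Split on the pipe separator and drop the bold markers used in some cells.
    cells = [cell.strip().strip("*") for cell in row.split("|")]
    tflops = [float(cell) for cell in cells[5:8]]
    best = max(range(len(tflops)), key=tflops.__getitem__)
    return KERNELS[best], tflops[best]


# Rows copied from the table: 512x512 and 1024x1024 Stable Diffusion shapes.
sd_512 = fastest_kernel("QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1")
sd_1024 = fastest_kernel("QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1")
```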
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention
CUDA kernels. A Linux build with CUDA might fail when the machine has
limited memory but a large number of CPUs. The workaround is to use a
build machine with more memory, or to pass an argument like
`--nvcc_threads 1` to limit the number of nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment
variables for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
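The layout difference amounts to swapping the sequence and head axes. The sketch below shows that index mapping on nested Python lists so it runs without any tensor library; `bsnh_to_bnsh` is a hypothetical helper and is not part of the benchmark script, which may handle layouts differently.

```python
def bsnh_to_bnsh(x):
    """Reorder x[b][s][n] (BSNH) into y[b][n][s] (BNSH); head vectors are kept as-is."""
    return [
        [[token[n] for token in batch] for n in range(len(batch[0]))]
        for batch in x
    ]


# Tiny example with B=1, S=2, N=3, H=1; each head vector is tagged with its indices.
x = [[[[f"s{s}n{n}"] for n in range(3)] for s in range(2)]]
y = bsnh_to_bnsh(x)
```

With a tensor library the same reordering is a single transpose of axes 1 and 2, which is the extra work that makes the two sets of measurements not directly comparable.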
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) on Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
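
The tflops column appears consistent with counting 4 * batch * seq_len^2 * num_heads * head_size floating point operations per non-causal forward pass (two matrix multiplications per head, each costing 2 * S * S * H). That formula is an assumption here, not something the description states, and `attention_tflops` is a hypothetical helper name.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPs/s for non-causal attention from a measured latency (assumed formula)."""
    # Q @ K^T and P @ V each cost 2 * S * S * H FLOPs per head and per batch item.
    flops = 4 * batch_size * sequence_length**2 * num_heads * head_size
    return flops / latency_s / 1e12


# First row of the table: batch 4, seq 2048, 32 heads, head size 128, 0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
```

The estimate lands near 183, slightly above the reported 179.5, which is plausible because the latency in the table is rounded to four decimal places.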
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) on Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare latency against PyTorch SDPA on CPU and CUDA.
2024-07-27 01:45:14 +00:00
Benchmark performance of MultiHeadAttention with ORT or PyTorch.

In Linux, run the following:
sh benchmark_mha.sh

In Windows, run the following:
benchmark_mha.cmd
"""
import argparse
import csv
|
2023-08-31 20:52:21 +00:00
import math
import os
2024-06-19 17:23:26 +00:00
import platform
2024-08-22 00:30:16 +00:00
import re
2023-08-31 20:52:21 +00:00
import statistics
2024-08-22 00:30:16 +00:00
import sys
import threading
2023-08-31 20:52:21 +00:00
import time
2024-07-27 01:45:14 +00:00
from contextlib import nullcontext
from datetime import datetime
from enum import IntEnum
from typing import Callable, Dict, List, Optional, Tuple
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention
CUDA kernels. A Linux build with CUDA might fail when the machine has
limited memory and a large number of CPUs. A workaround is to use a
build machine with more memory, or to pass an argument such as
`--nvcc_threads 1` to limit the number of nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
import torch
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the `sdpa_kernel` CUDA provider option instead of environment
variables for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing
different scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ORT uses the BSNH
layout, so the results are not an apples-to-apples comparison. However,
a large latency difference could still be a warning sign.
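
To make the layout difference concrete, the sketch below computes the element strides of a contiguous tensor in each layout. It is a plain-Python illustration and does not call torch or ORT; the helper name `contiguous_strides` is hypothetical, and the sizes are taken from one of the benchmark configurations.

```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a contiguous tensor of this shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


B, S, N, H = 4, 2048, 32, 128

bsnh = contiguous_strides((B, S, N, H))  # layout ORT uses for Q,K,V
bnsh = contiguous_strides((B, N, S, H))  # layout torch uses for Q,K,V

# Distance in elements between consecutive tokens of the same head.
token_step_bsnh = bsnh[1]  # N * H: heads of one token are stored together
token_step_bnsh = bnsh[2]  # H: tokens of one head are stored together
```

The two kernels therefore read the same values with different memory access patterns, which is one reason the latencies are only roughly comparable.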
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
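
The latency and tflops columns above are related by the number of floating-point operations in one attention call. The sketch below assumes the common count of 4 * B * S^2 * N * H operations for non-causal self-attention without past state (two matrix multiplications of 2 * B * N * S * S * H each); the benchmark script may compute it differently, and the function name `attention_tflops` is hypothetical.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_seconds):
    """Estimate TFLOPs per second for non-causal self-attention.

    Assumes 4 * B * S^2 * N * H floating-point operations per call:
    2 * B * N * S * S * H for Q x K^T and the same for scores x V.
    """
    flops = 4 * batch_size * sequence_length**2 * num_heads * head_size
    return flops / latency_seconds / 1e12


# First row of the table: B=4, S=2048, N=32, H=128, latency 0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
```

For that row the estimate is about 183, close to the reported 179.5; the gap is what one would expect from the latency column being rounded to four decimal places.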
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare CPU and CUDA latency with PyTorch SDPA.
2024-07-27 01:45:14 +00:00
import torch.utils.benchmark as benchmark
from onnx import TensorProto, helper
from packaging.version import Version
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
|
|
|
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the `sdpa_kernel` CUDA provider option instead of environment
variables for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses BNSH layout while ORT uses BSNH layout,
so the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
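To make the layout difference concrete, here is a minimal pure-Python sketch of where the same logical element (b, s, n, h) lives in a contiguous BSNH buffer versus a BNSH buffer. The helper names (`offset_bsnh`, `offset_bnsh`, `bsnh_to_bnsh`) are hypothetical and are not part of benchmark_mha.py. The sketch illustrates the index arithmetic only, not how either library stores or converts tensors.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Flat offset of the same element in a contiguous B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Copy a flat BSNH buffer into a new flat BNSH buffer (an explicit transpose)."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


# Tiny illustrative sizes (assumed for the example only).
B, S, N, H = 1, 2, 3, 4
bsnh = list(range(B * S * N * H))
bnsh = bsnh_to_bnsh(bsnh, B, S, N, H)
print(bnsh[:8])  # head 0 at s=0 followed by head 0 at s=1
```

Because the two layouts order the sequence and head axes differently, a kernel fed one layout may pay for a transpose that the other avoids. This is one reason the latencies of the two libraries are not directly comparable.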
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
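The tflops column above appears consistent with counting 4 * batch_size * sequence_length^2 * num_heads * head_size floating-point operations per non-causal attention call (two matrix multiplications, each costing 2 * S * S * H per head) and dividing by the latency. That formula is an assumption inferred from the table rows, not taken from benchmark_mha.py, and the function name below is hypothetical. A minimal sketch:

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPs/s for non-causal self-attention.

    Assumes 2 matmuls (Q*K^T and P*V), each costing
    2 * S * S * H floating-point operations per head and batch item.
    """
    flops = 4 * batch_size * sequence_length**2 * num_heads * head_size
    return flops / latency_s / 1e12


# First table row: batch 4, sequence 2048, 32 heads, head size 128, 0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f} TFLOPs/s")
```

With the rounded latency of 0.0015 s this gives roughly 183 TFLOPs/s, close to the reported 179.5. The gap is what one would expect if the table's tflops value was computed from the unrounded latency (about 0.00153 s).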
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnxruntime.transformers.io_binding_helper import CudaSession
class InputFormats:
    Q_K_V_BSNH_BSNH_BSNH = 0
    QKV_BSN3H = 1
    Q_KV_BSNH_BSN2H = 2
    Q_K_V_BSNH_BNSH_BNSH = 3  # For cross attention
    @staticmethod
    def input_format_str(format: int) -> str:
|
2024-06-12 20:04:25 +00:00
|
|
|
names = InputFormats.get_name_list()
|
|
|
|
|
return names[format]
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def convert(format_str: str) -> int:
|
|
|
|
|
names = InputFormats.get_name_list()
|
|
|
|
|
return names.index(format_str)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_name_list() -> List[str]:
|
|
|
|
|
return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"]
|
|
|
|
|
|
|
|
|
|
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
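Since `sdpa_kernel` is described as a set of bit flags, several kernels can presumably be allowed at once by OR-ing their values. The sketch below mirrors the flag values of the `SdpaKernel` enum in benchmark_mha.py; the class name `SdpaKernelFlags` and the helper `make_cuda_provider_options` are hypothetical and only illustrate how such a mask could be passed as a string-valued provider option.

```python
from enum import IntFlag


class SdpaKernelFlags(IntFlag):
    """Illustrative mirror of the SdpaKernel bit flags used by the benchmark."""

    DEFAULT = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    TRT_FUSED_ATTENTION = 4
    CUDNN_FLASH_ATTENTION = 8
    MATH = 16
    TRT_FLASH_ATTENTION = 32
    TRT_CROSS_ATTENTION = 64
    TRT_CAUSAL_ATTENTION = 128


def make_cuda_provider_options(kernels: SdpaKernelFlags) -> dict:
    """Hypothetical helper: build a provider-options dict carrying the mask."""
    return {"sdpa_kernel": str(int(kernels))}


# Allow flash attention, with the unfused math kernel as a fallback.
allowed = SdpaKernelFlags.FLASH_ATTENTION | SdpaKernelFlags.MATH
options = make_cuda_provider_options(allowed)
```

Keeping the mask as an integer string makes it easy to log next to each benchmark row, so a result can be tied back to the exact set of kernels that was permitted.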
For the Q,K,V format, torch uses the BNSH layout while ort uses the BSNH
layout, so the comparison is not apples-to-apples. However, a large
latency difference could still be a warning sign.
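To make the layout difference concrete, here is a small pure-Python sketch of how the same element lands at different flat offsets in BSNH and BNSH buffers. It assumes B, S, N and H stand for batch size, sequence length, number of heads and head size; the function names are illustrative and not part of the benchmark script.

```python
from typing import List


def offset_bsnh(b: int, s: int, n: int, h: int, S: int, N: int, H: int) -> int:
    """Flat offset of element (b, s, n, h) in a row-major B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b: int, s: int, n: int, h: int, S: int, N: int, H: int) -> int:
    """Flat offset of the same element in a row-major B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat: List[float], B: int, S: int, N: int, H: int) -> List[float]:
    """Copy a flat BSNH buffer into a new flat BNSH buffer (a transpose of S and N)."""
    out = [0.0] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


# One batch, two tokens, two heads, head size one.
converted = bsnh_to_bnsh([0.0, 1.0, 2.0, 3.0], B=1, S=2, N=2, H=1)
```

A kernel that expects one layout but receives the other needs a transpose like this one, which is one reason latencies measured across the two frameworks are not directly comparable.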
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
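The tflops column can be related to the latency column with a simple FLOP count. The sketch below assumes the common convention of counting only the two matrix multiplications (QK^T and the product with V) at 2 FLOPs per multiply-add, with equal query and key sequence lengths; halving the count for causal attention is also an assumption. The function names are illustrative.

```python
def attention_flops(batch_size: int, sequence_length: int, num_heads: int,
                    head_size: int, causal: bool = False) -> int:
    """Approximate FLOPs of one attention forward pass (two matmuls only)."""
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    # Assumption: a causal mask skips roughly half of the score matrix.
    return flops // 2 if causal else flops


def tflops_per_second(flops: int, latency_seconds: float) -> float:
    """Convert a FLOP count and a measured latency into TFLOPs per second."""
    return flops / latency_seconds / 1e12


# First row of the table: batch 4, sequence length 2048, 32 heads, head size 128.
flops = attention_flops(4, 2048, 32, 128)
estimate = tflops_per_second(flops, 0.0015)
```

With the latency rounded to 0.0015 s, the estimate comes out slightly above the 179.5 reported in the table, which is expected because the table latency is rounded to four decimals.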
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
|
|
|
class SdpaKernel(IntEnum):
|
|
|
|
|
"""Bit flags for sdpa_kernel CUDA provider option"""
|
|
|
|
|
|
|
|
|
|
DEFAULT = 0
|
|
|
|
|
FLASH_ATTENTION = 1
|
|
|
|
|
EFFICIENT_ATTENTION = 2
|
|
|
|
|
TRT_FUSED_ATTENTION = 4
|
|
|
|
|
CUDNN_FLASH_ATTENTION = 8
|
|
|
|
|
MATH = 16
|
|
|
|
|
TRT_FLASH_ATTENTION = 32
|
|
|
|
|
TRT_CROSS_ATTENTION = 64
|
|
|
|
|
TRT_CAUSAL_ATTENTION = 128
|
|
|
|
|
|
|
|
|
|
|
2024-08-09 08:31:00 +00:00
|
|
|
# Since attention bias is supported, we only need to support masks up to 2D.
|
|
|
|
|
class AttentionMaskFormat(IntEnum):
|
|
|
|
|
Mask_None = 0 # No attention mask.
|
|
|
|
|
Mask_1D_Key_SeqLen = 1 # Shape (batch_size), actual sequence lengths (excluding padding on the right side).
|
|
|
|
|
Mask_2D_Key_PaddingMask = 2 # Shape (batch_size, total_sequence_length), key padding mask.
|
|
|
|
|
|
|
|
|
|
|
2024-06-12 20:04:25 +00:00
|
|
|
class MultiHeadAttentionConfig:
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
batch_size: int,
|
|
|
|
|
sequence_length: int,
|
|
|
|
|
num_heads: int,
|
|
|
|
|
head_size: int,
|
|
|
|
|
causal: bool,
|
|
|
|
|
past_sequence_length: int = 0,
|
|
|
|
|
kv_sequence_length=None,
|
|
|
|
|
max_cache_sequence_length=None,
|
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
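The four supported shapes differ only in whether the first two dimensions are broadcast. A minimal sketch, assuming B, N, S and T denote batch size, number of heads, query sequence length and total key sequence length; the helper names are illustrative and not taken from the operator implementation:

```python
from typing import Tuple


def attention_bias_shape(batch_size: int, num_heads: int, sequence_length: int,
                         total_sequence_length: int,
                         broadcast_dim_0: bool = False,
                         broadcast_dim_1: bool = False) -> Tuple[int, int, int, int]:
    """Shape of the attention bias tensor for the chosen broadcast flags."""
    return (
        1 if broadcast_dim_0 else batch_size,
        1 if broadcast_dim_1 else num_heads,
        sequence_length,
        total_sequence_length,
    )


def bias_index(shape: Tuple[int, int, int, int], b: int, n: int, s: int, t: int) -> int:
    """Flat index of the bias value applied at (b, n, s, t), honoring broadcast dims."""
    dim0, dim1, S, T = shape
    b = 0 if dim0 == 1 else b  # broadcast over batch
    n = 0 if dim1 == 1 else n  # broadcast over heads
    return ((b * dim1 + n) * S + s) * T + t


shape = attention_bias_shape(2, 4, 3, 5, broadcast_dim_0=True, broadcast_dim_1=True)
```

A bias shared by all heads and batches, such as a single mask-like tensor, then only needs S x T values instead of B x N x S x T.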
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
|
|
|
scale: float = 0.0,
|
2024-06-19 17:23:26 +00:00
|
|
|
provider="CPUExecutionProvider",
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
enable_cuda_graph: bool = False,
|
|
|
|
|
dtype=torch.float,
|
2024-06-12 20:04:25 +00:00
|
|
|
use_kv_cache: bool = False,
|
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it might
be changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading
factory has a static mutex to guard against multiple threads), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
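The actual fix is in C++, but the once-only initialization pattern it describes can be sketched in Python as an analogy. Everything below is hypothetical (class names, the `init_count` counter, and the cache contents); it only illustrates a cache that is initialized exactly once and never modified afterwards.

```python
import threading
from typing import Callable, Tuple


class OnceFlag:
    """Minimal analogue of a once flag: runs a callable at most once."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._done = False

    def call_once(self, fn: Callable[[], None]) -> None:
        if self._done:
            return
        with self._lock:
            if not self._done:  # re-check while holding the lock
                fn()
                self._done = True


class SequenceLengthCache:
    """Hypothetical cache of cumulative offsets 0, S, 2S, ... (read-only after init)."""

    def __init__(self, max_batch_size: int, sequence_length: int) -> None:
        self._once = OnceFlag()
        self._max_batch_size = max_batch_size
        self._sequence_length = sequence_length
        self._offsets: Tuple[int, ...] = ()
        self.init_count = 0  # only here to observe how often init ran

    def _initialize(self) -> None:
        self.init_count += 1
        self._offsets = tuple(
            i * self._sequence_length for i in range(self._max_batch_size + 1)
        )

    def get(self) -> Tuple[int, ...]:
        self._once.call_once(self._initialize)
        return self._offsets  # an immutable tuple, so readers cannot change it


cache = SequenceLengthCache(max_batch_size=3, sequence_length=4)
workers = [threading.Thread(target=cache.get) for _ in range(8)]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
```

Returning an immutable tuple mirrors the "read only after initialization" part of the fix: concurrent readers share one buffer without any chance of overwriting it.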
(2) **Bias for cross attention**
That scenario assumes that only query has bias, and that key and value
do not. However, the assumption was not verified at runtime, there was
no comment documenting it, and there was no test case, so support for
the scenario was disabled by mistake. The scenario is actually used in
the whisper model (TODO: we shall add tests for whisper to CI pipeline, and
also update fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever there is such an assumption.
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats. That means some cases might fail since there was no fallback. I
added new AddBiasTranspose CUDA kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the
related code was scattered across different source files. I refactored
the code to move all related code into one file
(attention_prepare_qkv.cu) and added asserts, so that the logic stays in
sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
New code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and efficient attention**
Previously, flash or efficient attention did not run under the
following condition:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is like
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
We can see that we do not need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code saves two
transposes in this case.
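The branch described above can be sketched in Python as follows. This is only an illustration of the control flow, not the CUDA implementation: the names `add_bias` and `prepare_qkv` are hypothetical, tensors are modeled as lists of rows, and the new lists stand in for the q, k, v workspace.

```python
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]


def add_bias(rows: Matrix, bias: Sequence[float]) -> Matrix:
    """Return a new matrix (the 'workspace') with bias added to every row."""
    return [[x + b for x, b in zip(row, bias)] for row in rows]


def prepare_qkv(
    query: Matrix,
    key: Matrix,
    value: Matrix,
    bias: Optional[Tuple[Sequence[float], Sequence[float], Sequence[float]]] = None,
) -> Tuple[Matrix, Matrix, Matrix]:
    """Pick the q, k, v buffers that the attention kernel would read."""
    if bias is None:
        # No bias: use the inputs directly, so no extra workspace is allocated.
        return query, key, value
    q_bias, k_bias, v_bias = bias
    # Bias present: results are written to freshly allocated buffers.
    return add_bias(query, q_bias), add_bias(key, k_bias), add_bias(value, v_bias)


q_in, k_in, v_in = [[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]
direct = prepare_qkv(q_in, k_in, v_in)
biased = prepare_qkv(q_in, k_in, v_in, ([1.0, 1.0], [0.0, 0.0], [0.0, 0.0]))
```

In the no-bias branch the returned objects are the inputs themselves, which corresponds to the memory saving mentioned in the text.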
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format even when that was
not necessary. I changed the code to convert k/v to BSNH or BNSH case by
case, so that more cases can be covered by flash or efficient attention
to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I add a flag
for debug info in the AttentionData, so that we can output debug info
during processing.
Also add functions to consolidate the dumping of inputs, QKV processing
and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to
allow disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs, and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered
by flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | ---------
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported in CPU
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported in CPU
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer)
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
|
|
|
has_past_input: bool = False,
|
2024-06-12 20:04:25 +00:00
|
|
|
share_past_present_buffer: bool = False,
|
|
|
|
|
input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH,
|
2024-07-27 01:45:14 +00:00
|
|
|
verbose: bool = False,
|
2024-08-16 22:40:04 +00:00
|
|
|
has_bias: bool = False, # bias for input projection
|
|
|
|
|
has_attn_bias: bool = False, # bias added before softmax. For example,relative position bias.
|
|
|
|
|
broadcast_attn_bias_dim_0: bool = False, # broadcast attention bias dimension 0
|
|
|
|
|
broadcast_attn_bias_dim_1: bool = False, # broadcast attention bias dimension 1
|
2024-08-09 08:31:00 +00:00
|
|
|
mask_format: int = AttentionMaskFormat.Mask_None,
|
2024-06-12 20:04:25 +00:00
|
|
|
):
|
|
|
|
|
self.operator = "MultiHeadAttention"
|
|
|
|
|
self.batch_size = batch_size
|
|
|
|
|
self.sequence_length = sequence_length
|
|
|
|
|
self.kv_sequence_length = kv_sequence_length or sequence_length
|
|
|
|
|
self.max_cache_sequence_length = max_cache_sequence_length
|
|
|
|
|
self.past_sequence_length = past_sequence_length
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
self.head_size = head_size
|
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We did some change to remove dependency on Torch, then removed backward
and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works in Linux. For Windows, we will
enable later with Cutlass 3.2.
Two environment variables can be used for testing purpose:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
Note that flash attention cannot use packed QKV format, so extra
Transpose is needed. We found that TensorRT kernel is faster for
sequence length <= 512 for packed QKV. The reason might be no transpose
is needed for TensorRT kernel in this format.
We also notice that, TensorRT kernel is faster for stable diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim |
flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention
(TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention CUDA
kernels. A Linux build with CUDA might fail when the machine has limited
memory and a large number of CPUs. The workaround is to use a build machine
with more memory, or to pass an argument like `--nvcc_threads 1` to limit the
number of nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
|
|
|
self.causal = causal
|
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
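The four bias shapes listed above amount to allowing the first two dimensions to be either 1 or the full size. A minimal sketch of that rule follows; the function name `is_supported_attention_bias_shape` is hypothetical and is not taken from the ONNX Runtime sources, and the dimension names (B batch size, N heads, S sequence length, T total sequence length) are assumed from the usual convention for these operators.

```python
def is_supported_attention_bias_shape(bias_shape, batch_size, num_heads,
                                      sequence_length, total_sequence_length):
    """Return True if bias_shape is one of [1|B, 1|N, S, T].

    Hypothetical helper for illustration only: it encodes the four shapes
    named in the commit message, not the actual operator input checks.
    """
    if len(bias_shape) != 4:
        return False
    b, n, s, t = bias_shape
    return (
        b in (1, batch_size)              # broadcast over batch or full batch
        and n in (1, num_heads)           # broadcast over heads or full heads
        and s == sequence_length          # S is never broadcast
        and t == total_sequence_length    # T is never broadcast
    )
```

For example, with B=4, N=8, S=16, T=32, the shape `(4, 1, 16, 32)` passes the check while `(2, 8, 16, 32)` does not.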
2024-08-16 22:40:04 +00:00
|
|
|
self.scale = scale or (1.0 / (head_size**0.5))
|
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, then removed backward
and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works on Linux. For Windows, we will
enable it later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
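A minimal sketch of setting these two variables from Python. The variable names and values come from the description above; the helper name `set_flash_attention_env` is hypothetical, and it is an assumption that the variables need to be set before the inference session (and its attention kernels) is created.

```python
import os


def set_flash_attention_env(disable=False, min_seq_len_packed_qkv=513):
    """Set the two testing variables described above and return them.

    Hypothetical helper. Assumption: call it before creating the
    inference session so the kernels see the values.
    """
    values = {
        # "1" disables flash attention, "0" (the default) keeps it enabled.
        "ORT_DISABLE_FLASH_ATTENTION": "1" if disable else "0",
        # Packed QKV inputs use flash attention only when the sequence
        # length is at least this value; "0" means whenever possible.
        "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV": str(min_seq_len_packed_qkv),
    }
    os.environ.update(values)
    return values


# Use flash attention v2 whenever possible, including short packed QKV inputs.
set_flash_attention_env(disable=False, min_seq_len_packed_qkv=0)
```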
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
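To make the metric concrete, here is a sketch of how TFLOPs per second can be derived from a measured latency. It assumes the common convention of counting only the two batched matrix multiplications in non-causal self-attention (QK^T and the product with V), each costing 2·B·N·S·S·H floating point operations; whether benchmark_mha.sh uses exactly this count is an assumption, and the function names are illustrative.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size):
    """FLOPs of non-causal self-attention, counting only the two matmuls.

    Assumed convention: QK^T and softmax(QK^T)V each cost
    2 * B * N * S * S * H floating point operations.
    """
    return 4 * batch_size * num_heads * sequence_length**2 * head_size


def tflops_per_second(flops, seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPs per second."""
    return flops / seconds / 1e12


# Row with batch=1, seq_len=16384, heads=8, head_dim=160.
flops = attention_flops(1, 16384, 8, 160)
print(flops)  # 1374389534720
```

Under this convention, a throughput of 177.4 TFLOPs/s for that configuration would correspond to a latency of roughly 7.7 ms per call.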
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
Note that flash attention cannot use the packed QKV format, so an extra
Transpose is needed. We found that the TensorRT kernel is faster for
sequence length <= 512 with packed QKV. The reason might be that no transpose
is needed for the TensorRT kernel in this format.
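The three input formats can be summarized by the tensor shapes they imply. The sketch below is illustrative only: the helper name `mha_input_shapes` is hypothetical, it assumes self-attention (key/value sequence length equal to S), and it assumes the query in the `Q,KV` format keeps the 3D BxSxNH layout.

```python
def mha_input_shapes(input_format, batch_size, sequence_length, num_heads, head_size):
    """Return the input tensor shapes implied by each benchmark input format.

    Hypothetical helper; B, S, N, H follow the notation in the list above.
    """
    B, S, N, H = batch_size, sequence_length, num_heads, head_size
    if input_format == "Q,K,V":
        # Three separate 3D tensors of shape B x S x (N*H).
        return {"query": (B, S, N * H), "key": (B, S, N * H), "value": (B, S, N * H)}
    if input_format == "Q,KV":
        # Key and value are packed into one 5D tensor passed as key.
        return {"query": (B, S, N * H), "key": (B, S, N, 2, H)}
    if input_format == "QKV":
        # Query, key and value are packed into one 5D tensor passed as query.
        return {"query": (B, S, N, 3, H)}
    raise ValueError(f"unknown input format: {input_format}")
```

For batch=32, seq_len=512, heads=64, head_dim=32, the packed `QKV` format gives a single tensor of shape `(32, 512, 64, 3, 32)`.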
We also noticed that the TensorRT kernel is faster for stable diffusion
512x512 images (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 images (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention CUDA
kernels. A Linux build with CUDA might fail when the machine has limited
memory and a large number of CPUs. The workaround is to use a build machine
with more memory, or to pass an argument like `--nvcc_threads 1` to limit the
number of nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
|
|
|
|
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it might be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the factory for
kernel loading has a static mutex to guard against multiple threads), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This also requires some changes to the workspace computation, so I did
not create a separate pull request.
(2) **Bias for cross attention**
This scenario assumes that only the query has bias, not the key
and value. However, the assumption was not verified at runtime, there
was no comment documenting it, and there was no test case, so support
for the scenario was disabled by mistake. In fact, the scenario is used in
the whisper model (TODO: we shall add tests for whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever this assumption is made.
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats. That means some cases might fail since there was no fallback. I
added new AddBiasTranpose cuda kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the related
code is scattered across different source files. I refactored the code to
move all related code into one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, one condition prevented flash or efficient attention from
running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
  Add bias to query, key, value, and store in q, k, v workspace
else
  Use query, key and value directly as q, k and v in kernel
```
We can see that we do not need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two
transposes in this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were also converted to BSNH format, and some of those
conversions were not necessary. I changed the code to convert k/v to BSNH or
BNSH case by case, so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I added a flag for
debug info in the AttentionData, so that we can output debug info during
processing.
I also added functions to consolidate the dumping of inputs, QKV processing
and outputs, and added an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs, and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by flash
and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not support in CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
|
|
|
# Support the case that there is no past but need present output (for prompt case).
|
|
|
|
|
self.has_past_input = has_past_input
|
|
|
|
|
if has_past_input:
|
|
|
|
|
assert use_kv_cache
|
|
|
|
|
else: # no past input
|
|
|
|
|
assert past_sequence_length == 0
|
|
|
|
|
|
|
|
|
|
self.has_present_output = use_kv_cache
|
|
|
|
|
|
2024-06-12 20:04:25 +00:00
|
|
|
self.use_kv_cache = use_kv_cache
|
|
|
|
|
if not use_kv_cache:
|
|
|
|
|
assert past_sequence_length == 0
|
2024-06-19 17:23:26 +00:00
|
|
|
else:
|
|
|
|
|
assert self.kv_sequence_length == self.sequence_length
|
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
|
|
|
|
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
|
|
|
# Only BSNH input format supports past state.
|
|
|
|
|
if input_format != InputFormats.Q_K_V_BSNH_BSNH_BSNH:
|
|
|
|
|
assert not self.has_past_input
|
|
|
|
|
assert not self.has_present_output
|
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We did some change to remove dependency on Torch, then removed backward
and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works in Linux. For Windows, we will
enable later with Cutlass 3.2.
Two environment variables can be used for testing purpose:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
Note that flash attention cannot use the packed QKV format, so an extra
Transpose is needed. We found that the TensorRT kernel is faster for
sequence length <= 512 with packed QKV. The reason might be that no transpose
is needed for the TensorRT kernel in this format.
We also noticed that the TensorRT kernel is faster for a stable diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for a 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a huge amount of memory while compiling the flash attention CUDA kernels. A Linux
build with CUDA might fail when the machine has limited memory and a large number
of CPUs. A workaround is to use a build machine with more
memory, or to pass an argument like `--nvcc_threads 1` to limit the nvcc threads in
the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
2024-06-12 20:04:25 +00:00
# Derived values
self.total_sequence_length = self.kv_sequence_length + past_sequence_length
self.past_buffer_length = self.max_cache_sequence_length if share_past_present_buffer else past_sequence_length
self.present_buffer_length = (
    self.max_cache_sequence_length if share_past_present_buffer else self.total_sequence_length
)
2024-06-19 17:23:26 +00:00
self.provider = provider
2024-06-12 20:04:25 +00:00
self.device = device
2024-06-19 17:23:26 +00:00
self.enable_cuda_graph = enable_cuda_graph
self.dtype = dtype
2024-06-12 20:04:25 +00:00
self.share_past_present_buffer = share_past_present_buffer
self.input_format = input_format
self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H
self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ORT uses the BSNH layout, so
the comparison is not apples-to-apples. However, a large latency difference
could still be a warning sign.
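The layout difference mentioned above amounts to swapping the sequence and head axes. A small NumPy sketch, assuming B, S, N and H stand for batch size, sequence length, number of heads and head size (the array names are illustrative):

```python
import numpy as np

B, S, N, H = 2, 5, 3, 4
x_bsnh = np.arange(B * S * N * H, dtype=np.float32).reshape(B, S, N, H)

# BSNH -> BNSH: swap the sequence (axis 1) and head (axis 2) dimensions.
x_bnsh = x_bsnh.transpose(0, 2, 1, 3)

# transpose() returns a view; a kernel that needs contiguous BNSH memory
# pays for a real copy, which is part of why the comparison is not exact.
x_bnsh_contiguous = np.ascontiguousarray(x_bnsh)

print(x_bsnh.shape, x_bnsh.shape)
```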
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare CPU and CUDA latency with PyTorch SDPA.
2024-07-27 01:45:14 +00:00
self.verbose = verbose
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it might be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the factory for
kernel loading has a static mutex to guard against multiple threads), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This also requires some changes to the workspace computation, so I did
not create a separate pull request.
(2) **Bias for cross attention**
This scenario assumes that only the query has bias, and that key
and value do not. However, the assumption was not verified at runtime, it
was not documented in comments, and there was no test case, so support
for the scenario was disabled by mistake. The scenario is actually used in the
whisper model (TODO: we shall add tests for whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever this assumption is made.
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats. That means some cases might fail since there was no fallback. I
added new AddBiasTranspose CUDA kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the related
code was scattered across different source files. I refactored the code to
move all related code into one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage and less workspace and less transpose of flash and
efficient attention**
Previously, flash or efficient attention was not run when the following
condition was true:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, and also use less workspace.
For example, for cross attention with bias, the original code used two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
As a result, we do not need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two transposes in
this case.
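The branch described above can be sketched in NumPy. This is an illustration only, not the CUDA implementation: the function name `prepare_qkv` is hypothetical, and the sketch assumes the bias is a single 1D vector holding the query, key and value biases concatenated in that order.

```python
import numpy as np


def prepare_qkv(query, key, value, bias=None):
    """Return (q, k, v) for the attention kernel (hypothetical helper).

    query: (B, S, D_q), key: (B, L, D_k), value: (B, L, D_v).
    bias: assumed 1D vector of length D_q + D_k + D_v, or None.
    """
    if bias is None:
        # No bias: use the inputs directly, so no extra workspace is needed.
        return query, key, value
    d_q, d_k = query.shape[-1], key.shape[-1]
    q_bias, k_bias, v_bias = bias[:d_q], bias[d_q:d_q + d_k], bias[d_q + d_k:]
    # With bias: the sums are new arrays, standing in for the q, k, v workspace.
    return query + q_bias, key + k_bias, value + v_bias


query = np.zeros((1, 2, 4), dtype=np.float32)
key = np.zeros((1, 3, 4), dtype=np.float32)
value = np.zeros((1, 3, 4), dtype=np.float32)
bias = np.concatenate([np.full(4, 1.0), np.zeros(4), np.zeros(4)]).astype(np.float32)

q, k, v = prepare_qkv(query, key, value, bias)
q0, k0, v0 = prepare_qkv(query, key, value)
```

In the no-bias branch the returned objects are the inputs themselves, which mirrors the point that no temporary key or value workspace has to be allocated.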
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were also converted to BSNH format, and some of those
conversions were not necessary. I changed the code to convert k/v to BSNH or
BNSH case by case, so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I added a flag for
debug info to AttentionData, so that we can output debug info during
processing.
Also added functions to consolidate the dumping of inputs, QKV processing
and outputs, and added an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs, and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases in PrepareQKV to allow more cases to be covered by flash
and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | ---------
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported in CPU
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported in CPU
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer)
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
self.has_bias = has_bias
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
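The four accepted shapes correspond to ordinary broadcasting over the first two dimensions. A NumPy sketch of the idea, assuming the bias is added to attention scores of shape [B, N, S, T] before softmax; the function name `add_attention_bias` is hypothetical and this is not the operator's actual kernel code:

```python
import numpy as np


def add_attention_bias(scores, attn_bias):
    """Add a bias of shape [B or 1, N or 1, S, T] to scores of shape [B, N, S, T]."""
    B, N, S, T = scores.shape
    b0, b1, s, t = attn_bias.shape
    if b0 not in (1, B) or b1 not in (1, N) or (s, t) != (S, T):
        raise ValueError(f"attention bias shape {attn_bias.shape} cannot broadcast to {scores.shape}")
    # NumPy broadcasting repeats the bias along any dimension of size 1.
    return scores + attn_bias


B, N, S, T = 2, 3, 4, 5
scores = np.zeros((B, N, S, T), dtype=np.float32)
results = {
    shape: add_attention_bias(scores, np.ones(shape, dtype=np.float32))
    for shape in [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T)]
}
```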
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
self.has_attn_bias = has_attn_bias
self.broadcast_attn_bias_dim_0 = broadcast_attn_bias_dim_0
self.broadcast_attn_bias_dim_1 = broadcast_attn_bias_dim_1

assert mask_format in [
    AttentionMaskFormat.Mask_None,
    AttentionMaskFormat.Mask_1D_Key_SeqLen,
    AttentionMaskFormat.Mask_2D_Key_PaddingMask,
]
self.mask_format = mask_format

# mask_index_q and mask_index_kv will be updated in random_inputs() if mask_format is not Mask_None.
self.mask_index_kv = torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length
self.mask_index_q = (
    torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.total_sequence_length
)
2024-06-19 17:23:26 +00:00
2024-08-09 08:31:00 +00:00
2024-06-19 17:23:26 +00:00
def __repr__(self):
    return (
        f"MultiHeadAttentionConfig(batch_size={self.batch_size}, sequence_length={self.sequence_length}, "
        f"num_heads={self.num_heads}, head_size={self.head_size}, "
        f"kv_sequence_length={self.kv_sequence_length}, past_sequence_length={self.past_sequence_length}, "
        f"max_cache_sequence_length={self.max_cache_sequence_length}, "
        f"causal={self.causal}), scale={self.scale}, use_kv_cache={self.use_kv_cache}, "
2024-06-19 17:23:26 +00:00
        f"share_past_present_buffer={self.share_past_present_buffer}, "
        f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, "
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it might be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the factory for
kernel loading has a static mutex to guard against concurrent access), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
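The initialize-once, read-only-afterwards pattern described above can be sketched in Python as an analogue of a C++ once_flag. `OnceCache` is a hypothetical name used only for illustration; the actual fix lives in the C++ CUDA code and is not reproduced here.

```python
import threading
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class OnceCache(Generic[T]):
    """Hypothetical helper: build a value exactly once, then only read it."""

    def __init__(self, factory: Callable[[], T]) -> None:
        self._factory = factory
        self._lock = threading.Lock()
        self._initialized = False
        self._value: Optional[T] = None

    def get(self) -> T:
        # Double-checked locking: only the first caller runs the factory.
        if not self._initialized:
            with self._lock:
                if not self._initialized:
                    self._value = self._factory()
                    self._initialized = True
        return self._value  # type: ignore[return-value]


build_calls = []


def build_cumulative_lengths():
    build_calls.append(1)
    # A tuple is immutable, so readers cannot modify the cached content.
    return tuple(i * 128 for i in range(5))


cache = OnceCache(build_cumulative_lengths)
workers = [threading.Thread(target=cache.get) for _ in range(8)]
for w in workers:
    w.start()
for w in workers:
    w.join()
```

Storing the cached content as an immutable tuple mirrors the "read-only after initialization" part of the fix: later callers can only read it.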
(2) **Bias for cross attention**
That scenario assumes that only the query has bias, not the key
or value. However, the assumption was not verified at runtime, there
was no comment documenting it, and there was no test case, so support
for the scenario was disabled by mistake. In fact, the scenario is used in
the Whisper model (TODO: we shall add tests for Whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever the assumption is made.
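A check of that assumption could look like the following sketch. It assumes the bias is a 1-D sequence laid out as the query, key and value parts concatenated (length `hidden_size + hidden_size + v_hidden_size`); `key_value_bias_is_zero` is a hypothetical helper, not a function from ORT or its fusion script.

```python
from typing import Sequence


def key_value_bias_is_zero(bias: Sequence[float], hidden_size: int, v_hidden_size: int) -> bool:
    """Return True when only the query part of a packed [Q | K | V] bias is non-zero."""
    expected = 2 * hidden_size + v_hidden_size
    if len(bias) != expected:
        raise ValueError(f"bias has {len(bias)} elements, expected {expected}")
    # Everything after the first hidden_size entries belongs to key and value.
    return all(x == 0 for x in bias[hidden_size:])
```

A fusion script could run such a check before emitting a cross-attention node with bias, and fall back to a node without bias fusion otherwise.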
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats, so some cases might fail because there was no fallback. I
added new AddBiasTranspose CUDA kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync because the
related code was scattered across different source files. I refactored the code to
move all related code to one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, the following condition prevented flash or efficient
attention from running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, and with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
We can see that we do not need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two
transposes in this case.
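The new logic can be illustrated with a small Python sketch over flat buffers. It ignores layouts and transposes, assumes equal hidden sizes for Q, K and V, and uses hypothetical names (`add_bias`, `prepare_qkv`); the real implementation is the CUDA code in attention_prepare_qkv.cu.

```python
from typing import List, Optional, Sequence, Tuple


def add_bias(x: Sequence[float], bias: Sequence[float]) -> List[float]:
    """Add a per-channel bias to a flat buffer whose last dimension is len(bias)."""
    hidden = len(bias)
    return [v + bias[i % hidden] for i, v in enumerate(x)]


def prepare_qkv(
    query: Sequence[float],
    key: Sequence[float],
    value: Sequence[float],
    bias: Optional[Sequence[float]] = None,
) -> Tuple[Sequence[float], Sequence[float], Sequence[float], bool]:
    """Return (q, k, v, used_workspace)."""
    if bias is None:
        # No bias: the kernel reads the inputs directly, so no workspace is needed.
        return query, key, value, False
    hidden = len(bias) // 3  # assumption: same hidden size for Q, K and V
    q = add_bias(query, bias[:hidden])
    k = add_bias(key, bias[hidden : 2 * hidden])
    v = add_bias(value, bias[2 * hidden :])
    return q, k, v, True
```

Only the biased path writes new buffers; the unbiased path returns the inputs unchanged, which is where the saved workspace comes from in this simplified picture.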
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format even when that was not
necessary. I changed the code to convert k/v to BSNH or BNSH case by case,
so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I add a flag for
debug info in AttentionData, so that we can output debug info during
processing.
Also add functions to consolidate the dumping of inputs, QKV processing
and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` that allows
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases in PrepareQKV so that more cases are covered by flash
and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported in CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported in CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
        f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, "
2024-08-16 22:40:04 +00:00
        f"has_bias={self.has_bias}, mask_format={self.mask_format}, "
        f"has_attn_bias={self.has_attn_bias}, "
        f"broadcast_attn_bias_dim_0={self.broadcast_attn_bias_dim_0}, "
        f"broadcast_attn_bias_dim_1={self.broadcast_attn_bias_dim_1}, "
2024-06-19 17:23:26 +00:00
    )
2024-06-12 20:04:25 +00:00

def shape_dict(self, input_format=None):
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses the BSNH layout, so
the comparison is not apples-to-apples. However, a large latency difference
could be a warning sign.
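The difference between the two layouts is only the order of the sequence (S) and head (N) axes in memory. A minimal pure-Python sketch, with hypothetical helper names and row-major storage assumed, shows the index mapping that a BSNH to BNSH conversion has to perform:

```python
from typing import List, Sequence


def offset_bsnh(b: int, s: int, n: int, h: int, S: int, N: int, H: int) -> int:
    """Flat offset of element (b, s, n, h) in a row-major B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b: int, n: int, s: int, h: int, N: int, S: int, H: int) -> int:
    """Flat offset of element (b, n, s, h) in a row-major B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(data: Sequence[float], B: int, S: int, N: int, H: int) -> List[float]:
    """Copy a flat BSNH buffer into BNSH order."""
    out = [0.0] * (B * S * N * H)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    out[offset_bnsh(b, n, s, h, N, S, H)] = data[offset_bsnh(b, s, n, h, S, N, H)]
    return out
```

Because one side may need such a transpose while the other does not, small latency gaps between the torch and ort rows should not be over-interpreted.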
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
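The tflops column can be related to the latency column with a short sketch. It assumes the common estimate of 4·B·N·S²·H floating point operations for non-causal attention (two matrix multiplications of 2·B·N·S²·H each); whether benchmark_mha.py uses exactly this formula is an assumption here, and the function names are hypothetical.

```python
def attention_flops(batch_size: int, sequence_length: int, num_heads: int, head_size: int,
                    causal: bool = False) -> int:
    """Estimated floating point operations for Q*K^T and P*V (softmax ignored)."""
    flops = 4 * batch_size * num_heads * sequence_length**2 * head_size
    # Assumption: causal masking skips roughly half of the score matrix.
    return flops // 2 if causal else flops


def tflops(latency_seconds: float, batch_size: int, sequence_length: int,
           num_heads: int, head_size: int, causal: bool = False) -> float:
    """Throughput in TFLOPs per second for one attention call."""
    return attention_flops(batch_size, sequence_length, num_heads, head_size, causal) / latency_seconds / 1e12


# First row of the table: batch 4, sequence length 2048, 32 heads, head size 128.
estimate = tflops(0.0015, 4, 2048, 32, 128)
```

For the first row this estimate is roughly 183, in the same range as the reported 179.5; the gap is consistent with the latency column being rounded to four decimal places.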
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
    shapes: Dict[str, Tuple] = {
        "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
    }

2024-06-12 20:04:25 +00:00
    input_format = input_format or self.input_format
2024-07-27 01:45:14 +00:00
    if input_format == InputFormats.QKV_BSN3H:
        shapes = {
            **shapes,
            "query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size),
        }
    elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
        shapes = {
            **shapes,
            "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
            "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size),
        }
    elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
        shapes = {
            **shapes,
            "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
            "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
            "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
        }
    else:
        assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH
        shapes = {
            **shapes,
2024-06-12 20:04:25 +00:00
            "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
            "key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size),
            "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size),
        }
2024-07-31 16:01:05 +00:00
    if self.has_past_input:
2024-06-12 20:04:25 +00:00
        shapes = {
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
senarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the result is not apple-to-apple. However, if the latency difference is
large, that could be a warning.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency
(s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
                **shapes,
2024-06-12 20:04:25 +00:00
                "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size),
                "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size),
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache so that it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so another
thread could change it and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading
factory has a static mutex to guard against multi-threading), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
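The once-only initialization described above can be illustrated with a small Python analogue. This is only a sketch: the actual fix uses a C++ once_flag, whereas here a lock plus a flag stands in for it. The `OnceCache` class, its `_build` helper, and the assumption that the cache holds offsets 0, S, 2S, ... are hypothetical and chosen for illustration.

```python
import threading


class OnceCache:
    """Hypothetical analogue of a cache that is built once and then read-only."""

    def __init__(self, max_batch_size, sequence_length):
        self._lock = threading.Lock()
        self._initialized = False
        self._data = None
        self._max_batch_size = max_batch_size
        self._sequence_length = sequence_length
        self.init_count = 0  # for demonstration only

    def _build(self):
        # Assumed content: cumulated sequence lengths 0, S, 2S, ..., B*S.
        self.init_count += 1
        return tuple(i * self._sequence_length for i in range(self._max_batch_size + 1))

    def get(self):
        # The lock and flag play the role of a once flag: only the first caller builds.
        with self._lock:
            if not self._initialized:
                self._data = self._build()
                self._initialized = True
        # A tuple is immutable, so no thread can modify the cache afterwards.
        return self._data


cache = OnceCache(max_batch_size=4, sequence_length=128)
threads = [threading.Thread(target=cache.get) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

However many threads call `get()`, the build step runs a single time and every caller sees the same immutable data.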
(2) **Bias for cross attention**
That scenario assumes that only the query has bias, not the key and
value. However, the assumption was not verified at runtime, there was no
comment documenting it, and there was no test case, so support for the
scenario was disabled by mistake. The scenario is actually used in the
Whisper model (TODO: we shall add tests for Whisper to the CI pipeline,
and also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever the assumption is made.
(3) **Fallback support**
Previously, the unfused kernel did not support packed QKV and packed KV
formats, so some cases might fail since there was no fallback. I
added new AddBiasTranspose CUDA kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the
related code was scattered across different source files. I refactored
the code to move all related code into one file
(attention_prepare_qkv.cu) and added asserts, so that the logic stays in
sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so
the logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, there was one condition under which flash or efficient
attention was not run:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
  Add bias to query, key, value, and store in q, k, v workspace
else
  Use query, key and value directly as q, k and v in kernel
```
With this logic, we no longer need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two
transposes in this case.
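The bias branch above can be sketched in Python. This is an illustrative sketch, not the CUDA implementation: the `prepare_qkv` function is hypothetical, inputs are plain lists of rows, and it assumes query, key and value share one hidden size with the bias laid out as Q, K, V segments back to back.

```python
def prepare_qkv(query, key, value, bias=None):
    """Return (q, k, v, uses_workspace) for inputs given as lists of rows."""
    if bias is None:
        # No bias: the kernel can read the inputs directly, so no copy is made.
        return query, key, value, False

    hidden = len(query[0])
    # Assumed layout: bias = [q_bias | k_bias | v_bias], each of length `hidden`.
    q_bias = bias[:hidden]
    k_bias = bias[hidden:2 * hidden]
    v_bias = bias[2 * hidden:3 * hidden]

    def add_bias(rows, segment):
        # Writing into a new list models storing the result in a workspace.
        return [[x + b for x, b in zip(row, segment)] for row in rows]

    return add_bias(query, q_bias), add_bias(key, k_bias), add_bias(value, v_bias), True


query = [[1.0, 2.0]]
key = [[3.0, 4.0]]
value = [[5.0, 6.0]]
q0, k0, v0, ws0 = prepare_qkv(query, key, value)
q1, k1, v1, ws1 = prepare_qkv(query, key, value, bias=[0.5, 0.5, 0.0, 0.0, 0.0, 0.0])
```

In the second call only the query segment of the bias is non-zero, which mirrors the cross attention assumption that key and value carry no bias.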
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format even when that was
not necessary. I changed the code to convert k/v to BSNH or BNSH case by
case, so that more cases can be covered by flash or efficient attention
to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I added a flag
for debug info to AttentionData, so that we can output debug info during
processing.
I also added functions to consolidate the dumping of inputs, QKV
processing and outputs, and added an environment variable
`ORT_ENABLE_GPU_DUMP` to allow disabling dumping from the CUDA kernel.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered
by flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
            }
        if self.has_present_output:
            shapes = {
                **shapes,
2024-06-12 20:04:25 +00:00
                "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size),
                "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size),
            }
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
        if self.has_bias:
            shapes["bias"] = (3 * self.num_heads * self.head_size,)
2024-08-09 08:31:00 +00:00
        if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
            shapes["mask"] = (self.batch_size,)
        elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
            shapes["mask"] = (self.batch_size, self.total_sequence_length)
        else:
            assert self.mask_format == AttentionMaskFormat.Mask_None
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N,
S, T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
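
The four accepted attention bias shapes listed above can be checked and indexed with a few lines of Python. This is a minimal sketch for illustration: both function names are hypothetical, and the row-major offset computation is an assumption about layout rather than a copy of the kernel code.

```python
def is_valid_attention_bias_shape(shape, batch_size, num_heads,
                                  sequence_length, total_sequence_length):
    """Accept [1 or B, 1 or N, S, T] as described in the text."""
    if len(shape) != 4:
        return False
    dim_0, dim_1, dim_2, dim_3 = shape
    return (dim_0 in (1, batch_size)
            and dim_1 in (1, num_heads)
            and dim_2 == sequence_length
            and dim_3 == total_sequence_length)


def attention_bias_offset(shape, b, n, s, t):
    """Flat offset of the bias element used for (b, n, s, t).

    A dimension of size 1 is broadcast, so its index is forced to 0.
    """
    dim_0, dim_1, dim_2, dim_3 = shape
    b = 0 if dim_0 == 1 else b
    n = 0 if dim_1 == 1 else n
    return ((b * dim_1 + n) * dim_2 + s) * dim_3 + t


B, N, S, T = 2, 4, 3, 5
fully_broadcast = (1, 1, S, T)
head_broadcast = (B, 1, S, T)
```

With `fully_broadcast`, every batch and head reads the same S x T slice; with `head_broadcast`, each batch has its own slice shared by all heads.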
2024-08-16 22:40:04 +00:00
        if self.has_attn_bias:
            shapes["attn_bias"] = (
                1 if self.broadcast_attn_bias_dim_0 else self.batch_size,
                1 if self.broadcast_attn_bias_dim_1 else self.num_heads,
                self.sequence_length,
                self.total_sequence_length,
            )
2024-06-12 20:04:25 +00:00
        return shapes
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)
### Description
- [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe.
- [x] Add multi-threading tests
Previously, the kernel parameters (params) were stored as a member of
the MHA runner, which means that different threads could change the
params at the same time and affect each other.
For example, if batch_size and seq_len were changed by another thread to
larger values in setup(...), a buffer overrun might happen in run(...)
because a kernel could read/write memory outside the allocated
buffers.
In the new implementation, I changed the API and removed mutable member
variables to make it thread-safe. Below is a summary of the change:
Before:
```
class FusedMHARunnerFP16v2::mhaImpl {
void setup(int seq_len, int batch_size) {
// change scalar params
}
void run(input, output) {
// change params for input and output pointers
// launch kernel using params
}
Fused_multihead_attention_params_v2 params; // mutable, not thread-safe
}
```
After:
```
class FusedMHARunnerFP16v2::FmhaImpl {
void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
// change params
}
void run(params, input, output) {
// change params with input and output pointers
// launch kernel using params
}
}
```
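The same idea can be tried out in Python. The sketch below is an analogue, not the C++ code: `KernelParams` and `StatelessRunner` are hypothetical names, and the "kernel" is reduced to a bounds check so that the effect of per-call parameters is observable.

```python
import threading
from dataclasses import dataclass


@dataclass
class KernelParams:
    """Hypothetical per-call parameters (stand-in for the C++ params struct)."""
    seq_len: int = 0
    batch_size: int = 0


class StatelessRunner:
    """Holds no mutable state, so one instance can be shared across threads."""

    def setup(self, seq_len, batch_size, params):
        params.seq_len = seq_len
        params.batch_size = batch_size

    def run(self, params, buffer):
        # The "kernel" only touches the range described by the caller's params.
        needed = params.seq_len * params.batch_size
        return needed <= len(buffer)


runner = StatelessRunner()
results = []
results_lock = threading.Lock()


def worker(seq_len, batch_size):
    params = KernelParams()  # owned by this thread only
    buffer = [0] * (seq_len * batch_size)
    runner.setup(seq_len, batch_size, params)
    ok = runner.run(params, buffer)
    with results_lock:
        results.append(ok)


threads = [threading.Thread(target=worker, args=(64 * (i + 1), i + 1)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Because each thread passes its own `KernelParams`, no thread can enlarge the sizes that another thread's `run` relies on, which is the failure mode described for the shared member variable.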
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
https://github.com/microsoft/onnxruntime/issues/21413
2024-07-22 17:41:08 +00:00
    def symbolic_shape_dict(self, input_format=None):
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
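To make the layout difference concrete, the sketch below computes flat offsets for BSNH and BNSH and converts a small tensor between them in pure Python. The function names are hypothetical and the code assumes contiguous row-major storage; it is meant only to show why the two frameworks read memory in a different order.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset in a contiguous (B, S, N, H) tensor."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, n, s, h, N, S, H):
    """Flat offset in a contiguous (B, N, S, H) tensor."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Reorder a flat BSNH buffer into BNSH order."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, n, s, h, N, S, H)
                    out[dst] = flat[src]
    return out


# B=1, S=2, N=2, H=1: BSNH order is (s0,n0), (s0,n1), (s1,n0), (s1,n1).
converted = bsnh_to_bnsh([0, 1, 2, 3], B=1, S=2, N=2, H=1)
```

In BNSH order the elements of one head are adjacent across sequence positions, so `converted` groups the two values of head 0 first and then the two values of head 1.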
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
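
The tflops column can be related to latency with a short calculation. The sketch below assumes the common convention of counting only the two matrix multiplications of non-causal attention (Q times K transposed, then probabilities times V), each costing 2 * B * N * S * S * H floating point operations, with softmax ignored. That convention and the `attention_tflops` name are assumptions for illustration, not taken from the benchmark script.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_seconds):
    """TFLOPs per second for non-causal attention under the assumed FLOP count."""
    # Two matmuls, each 2 * B * N * S * S * H floating point operations.
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_seconds / 1e12


# First row of the table: batch_size=4, sequence_length=2048, latency shown as 0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
```

The estimate comes out slightly above the 179.5 reported in the table, which is expected because the latency column is rounded to four decimal places.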
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
        shapes: Dict[str, Tuple] = {
            "output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
        }
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)
2024-07-22 17:41:08 +00:00
        input_format = input_format or self.input_format
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
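The tflops column can be related to the latency column with a short calculation. This is a sketch that assumes the usual non-causal count of 4 * B * N * S * S * H floating point operations (two matrix multiplications per head); the function name is hypothetical and the benchmark script may count differently.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPs per second for non-causal attention without past state."""
    # Q @ K^T and P @ V each cost 2 * B * N * S * S * H FLOPs (assumed count).
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_s / 1e12


# First row of the table: batch_size=4, sequence_length=2048, latency=0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f}")
```

The estimate comes out slightly above the reported 179.5 because the latency in the table is rounded to four decimal places.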
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
if input_format == InputFormats.QKV_BSN3H:
    shapes = {
        **shapes,
        "query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size),
    }
elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
    shapes = {
        **shapes,
        "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
        "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size),
    }
elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
    shapes = {
        **shapes,
        "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
        "key": ("batch_size", "sequence_length", self.num_heads * self.head_size),
        "value": ("batch_size", "sequence_length", self.num_heads * self.head_size),
    }
else:
    assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH
    shapes = {
        **shapes,
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)
### Description
- [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe.
- [x] Add multi-threading tests
Previously, the kernel parameters (params) were stored as a member of the
MHA runner, which means that different threads might change the params at
the same time and affect other threads.
For example, if batch_size and seq_len were changed by another thread to
larger values in setup(...), a buffer overrun might happen in run(...)
because a kernel could read/write memory outside the allocated buffers.
In the new implementation, I changed the API and removed mutable member
variables to make it thread-safe. Below is a summary of the change:
Before:
```
class FusedMHARunnerFP16v2::mhaImpl {
  void setup(int seq_len, int batch_size) {
    // change scalar params
  }
  void run(input, output) {
    // change params for input and output pointers
    // launch kernel using params
  }
  Fused_multihead_attention_params_v2 params; // mutable, not thread-safe
}
```
After:
```
class FusedMHARunnerFP16v2::FmhaImpl {
  void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
    // change params
  }
  void run(params, input, output) {
    // change params with input and output pointers
    // launch kernel using params
  }
}
```
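The same idea can be sketched in Python: the runner keeps no per-call state, and each caller owns the params it passes to run. This is only an illustration of the pattern. `Params` and `Runner` are hypothetical stand-ins, and `setup` returns the params object instead of filling a reference as the C++ code does.

```python
import threading
from dataclasses import dataclass


@dataclass
class Params:
    # Hypothetical stand-in for Fused_multihead_attention_params_v2.
    seq_len: int = 0
    batch_size: int = 0


class Runner:
    """Holds no mutable per-call state, so one instance can be shared by threads."""

    def setup(self, seq_len, batch_size):
        # Each caller gets its own Params object instead of mutating a member.
        return Params(seq_len=seq_len, batch_size=batch_size)

    def run(self, params, buffer):
        # The "kernel" only touches the range described by the caller's params.
        n = params.seq_len * params.batch_size
        assert n <= len(buffer), "would overrun the caller's buffer"
        for i in range(n):
            buffer[i] = 1
        return n


runner = Runner()
results = {}


def worker(seq_len, batch_size):
    buffer = [0] * (seq_len * batch_size)
    params = runner.setup(seq_len, batch_size)
    results[(seq_len, batch_size)] = runner.run(params, buffer)


threads = [threading.Thread(target=worker, args=(s, b)) for s, b in [(8, 2), (64, 4), (128, 8)]]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)
```

Because no thread can change another thread's sizes between setup and run, the bounds check inside run always matches the buffer the caller allocated.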
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
https://github.com/microsoft/onnxruntime/issues/21413
2024-07-22 17:41:08 +00:00
        "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
        "key": ("batch_size", self.num_heads, "sequence_length", self.head_size),
        "value": ("batch_size", self.num_heads, "sequence_length", self.head_size),
    }
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it could be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading
factory has a static mutex to guard against concurrent access), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
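The initialize-once, read-only-afterwards pattern can be sketched in Python as follows. Python has no once_flag, so a lock plays that role here. The class name and the cached contents (cumulative offsets for a fixed sequence length) are assumptions made for illustration, not the actual C++ implementation.

```python
import threading


class CumulatedSequenceLengthCacheSketch:
    """Hypothetical cache: built once on first use, immutable afterwards."""

    def __init__(self, max_batch_size, sequence_length):
        self._lock = threading.Lock()
        self._max_batch_size = max_batch_size
        self._sequence_length = sequence_length
        self._offsets = None
        self.init_count = 0  # only used to show that initialization runs once

    def get(self):
        with self._lock:
            if self._offsets is None:
                self.init_count += 1
                # A tuple is immutable, so readers cannot modify the cache.
                self._offsets = tuple(
                    i * self._sequence_length for i in range(self._max_batch_size + 1)
                )
            return self._offsets


cache = CumulatedSequenceLengthCacheSketch(max_batch_size=4, sequence_length=4)
seen = []
threads = [threading.Thread(target=lambda: seen.append(cache.get())) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(cache.init_count, seen[0])
```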
(2) **Bias for cross attention**
This scenario assumes that only the query has bias, and that key and
value do not. However, the assumption was not verified at runtime, there
was no comment documenting it, and there was no test case, so support
for the scenario was disabled by mistake. The scenario is actually used in
the Whisper model (TODO: we shall add tests for Whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever this assumption applies.
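A check for this assumption could look like the sketch below. It assumes the bias is a flat vector packed as the query segment, then the key segment, then the value segment; the function name is hypothetical and this is not the check used in the operator.

```python
def bias_is_query_only(bias, hidden_size, v_hidden_size):
    """Return True if the key and value segments of a packed bias are all zero."""
    # Assumed layout: [query bias | key bias | value bias].
    assert len(bias) == 2 * hidden_size + v_hidden_size
    return all(x == 0 for x in bias[hidden_size:])


ok_bias = [0.1, 0.2, 0.0, 0.0, 0.0, 0.0]   # only the query segment is non-zero
bad_bias = [0.1, 0.2, 0.0, 0.5, 0.0, 0.0]  # key segment has a non-zero entry

print(bias_is_query_only(ok_bias, 2, 2), bias_is_query_only(bad_bias, 2, 2))
```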
(3) **Fallback support**
Previously, the unfused kernel did not support packed QKV and packed KV
formats, so some cases might fail because there was no fallback. I
added new AddBiasTranspose CUDA kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync because the
related code was scattered across different source files. I refactored the
code to move all related code into one file (attention_prepare_qkv.cu) and
added asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage and less workspace and less transpose of flash and
efficient attention**
Previously, one condition prevented flash or efficient attention from
running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, flash and efficient attention can be used for this
case, with less workspace.
For example, for cross attention with bias, the original code used two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
  Add bias to query, key, value, and store in q, k, v workspace
else
  Use query, key and value directly as q, k and v in kernel
```
We no longer need to allocate temp_k_workspace and temp_v_workspace, so
less memory is used. The new code also saves two transposes in
this case.
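The branch above can be sketched in Python. This is an illustration only: tensors are modeled as lists of rows, the bias is passed as three per-row vectors, and `prepare_qkv` here is a hypothetical helper rather than the CUDA PrepareQKV function.

```python
def add_bias(rows, bias):
    """Return a new list of rows with the bias vector added to each row."""
    return [[x + b for x, b in zip(row, bias)] for row in rows]


def prepare_qkv(query, key, value, bias=None):
    """Return (q, k, v, workspace_elements) following the branch described above."""
    if bias is None:
        # No bias: the kernel reads the inputs directly, no workspace needed.
        return query, key, value, 0
    q_bias, k_bias, v_bias = bias
    q, k, v = add_bias(query, q_bias), add_bias(key, k_bias), add_bias(value, v_bias)
    workspace = sum(len(row) for t in (q, k, v) for row in t)
    return q, k, v, workspace


query = [[1.0, 2.0]]
key = [[3.0, 4.0]]
value = [[5.0, 6.0]]

q0, k0, v0, ws0 = prepare_qkv(query, key, value)
q1, k1, v1, ws1 = prepare_qkv(query, key, value, bias=([0.5, 0.5], [0.0, 0.0], [0.0, 0.0]))
print(ws0, ws1, q1)
```

Without bias the returned q, k and v are the input objects themselves, which mirrors the point that no copy or transpose is needed in that path.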
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format even when that was not
necessary. I changed the code to convert k/v to BSNH or BNSH case by case,
so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I added a flag for
debug info in AttentionData so that we can output debug info during
processing.
I also added functions to consolidate the dumping of inputs, QKV processing
and outputs, and added an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by
flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention for CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported in CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported in CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
if self.has_past_input:
    shapes = {
        **shapes,
        "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
        "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
    }

if self.has_present_output:
    shapes = {
        **shapes,
        "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
        "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
    }
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it could be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading
factory has a static mutex to guard against concurrent access), so the
mutable variable could be set by two different threads at the same time.
Add a once_flag to avoid that.
This required some workspace computation changes as well, so I did
not create a separate pull request.
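The initialize-once, read-only-afterwards idea in (1) can be sketched in Python. This is only an analogue: Python has no `once_flag`, so the sketch assumes a `threading.Lock` with a double check, and `OnceCache`, `cumulative_lengths` and `fetch_from_threads` are hypothetical names that do not appear in the change.

```python
import threading


class OnceCache:
    """Hypothetical analogue of a cache guarded by a once_flag."""

    def __init__(self, factory):
        self._factory = factory
        self._lock = threading.Lock()
        self._initialized = False
        self._value = None
        self.init_count = 0  # only here so the sketch can be checked

    def get(self):
        if not self._initialized:  # fast path once initialization is done
            with self._lock:
                if not self._initialized:  # re-check while holding the lock
                    # Store a tuple so callers cannot mutate the cached content.
                    self._value = tuple(self._factory())
                    self.init_count += 1
                    self._initialized = True
        return self._value


def cumulative_lengths(seq_len, max_batch):
    """Offsets 0, S, 2S, ... for max_batch sequences of equal length S."""
    return [i * seq_len for i in range(max_batch + 1)]


def fetch_from_threads(cache, num_threads):
    results = [None] * num_threads

    def work(i):
        results[i] = cache.get()

    threads = [threading.Thread(target=work, args=(i,)) for i in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

Every thread receives the same immutable object, so no thread can change lengths that another thread's kernel launch depends on.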
(2) **Bias for cross attention**
That scenario assumes that only the query has bias, and that key
and value do not. However, the assumption was not verified at runtime,
it was not documented in comments, and there was no test case, so support
for the scenario was disabled by mistake. The scenario is in fact used in
the whisper model (TODO: we shall add tests for whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever this assumption is made.
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats, so some cases might fail because there was no fallback. I
added new AddBiasTranspose CUDA kernels for these formats to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync because the related
code was scattered across different source files. I refactored the code to
move all related code into one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, the following condition prevented flash or efficient
attention from running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention in this
case, and with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
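The branching above can be written as a small Python sketch over nested lists. It is an illustration only, not the CUDA code: `prepare_qkv` and `add_bias` are hypothetical names, and the fourth return value is an assumed flag that reports whether separate q/k/v workspaces were needed.

```python
def add_bias(rows, bias):
    """Add a per-column bias vector to every row; returns new rows."""
    return [[x + b for x, b in zip(row, bias)] for row in rows]


def prepare_qkv(query, key, value, bias=None):
    """Return (q, k, v, used_workspace).

    bias, when given, is a (bias_q, bias_k, bias_v) triple of vectors.
    """
    if bias is None:
        # No bias: hand the inputs to the kernel as-is, with no copy.
        return query, key, value, False
    bias_q, bias_k, bias_v = bias
    # Bias present: results are written to separate q, k, v workspaces.
    return add_bias(query, bias_q), add_bias(key, bias_k), add_bias(value, bias_v), True
```

For the cross-attention case described in this change, `bias_k` and `bias_v` would be all zeros, so only `q` differs from its input.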
We no longer need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two transposes in
this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format even when that was not
necessary. I changed the code to convert k/v to BSNH or BNSH case by case,
so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, little debug info was available. In this change, I add a flag for
debug info in AttentionData, so that we can output debug info during
processing.
I also add functions to consolidate the dumping of inputs, QKV processing
and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases in PrepareQKV to allow more cases to be covered by flash
and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value have no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
|
|
|
if self.has_bias:
|
|
|
|
|
shapes["bias"] = (3 * self.num_heads * self.head_size,)
|
|
|
|
|
|
2024-08-09 08:31:00 +00:00
|
|
|
if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
|
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
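
The four accepted attention bias shapes can be expressed as a short shape check. This is a sketch under stated assumptions, not the operator's validation code: `is_supported_attn_bias_shape` and `broadcast_index` are hypothetical helper names, and B, N, S, T stand for batch size, number of heads, sequence length and total sequence length as in the description above.

```python
def is_supported_attn_bias_shape(shape, batch_size, num_heads, seq_len, total_seq_len):
    """True for [1|B, 1|N, S, T]; the first two dimensions may broadcast."""
    if len(shape) != 4:
        return False
    b, n, s, t = shape
    return (
        b in (1, batch_size)
        and n in (1, num_heads)
        and s == seq_len
        and t == total_seq_len
    )


def broadcast_index(shape, batch, head):
    """Map a (batch, head) pair to the (batch, head) slot stored in the bias."""
    return (batch if shape[0] > 1 else 0, head if shape[1] > 1 else 0)
```

For example, with a bias of shape [1, N, S, T] every batch reads slot 0 along the first dimension, while each head still reads its own slice.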
2024-08-16 22:40:04 +00:00
|
|
|
shapes["mask"] = ("batch_size",)
|
2024-08-09 08:31:00 +00:00
|
|
|
elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
|
|
|
|
shapes["mask"] = ("batch_size", "total_sequence_length")
|
2024-08-09 08:31:00 +00:00
|
|
|
else:
|
|
|
|
|
assert self.mask_format == AttentionMaskFormat.Mask_None
|
|
|
|
|
|
|
|
|
if self.has_attn_bias:
|
|
|
|
|
shapes["attn_bias"] = ("batch_size_or_1", "num_heads_or_1", "sequence_length", "total_sequence_length")
|
|
|
|
|
|
|
|
|
return shapes
|
|
|
|
|
|
2024-08-09 08:31:00 +00:00
|
|
|
def right_side_padding_masks(self):
|
|
|
|
|
q_mask = torch.ones(self.batch_size, 1, self.sequence_length, 1, dtype=torch.bool, device=self.device)
|
|
|
|
|
k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device)
|
|
|
|
|
mask = torch.ones(
|
|
|
|
|
self.batch_size,
|
|
|
|
|
self.num_heads,
|
|
|
|
|
self.sequence_length,
|
|
|
|
|
self.total_sequence_length,
|
|
|
|
|
dtype=torch.bool,
|
|
|
|
|
device=self.device,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.mask_format != AttentionMaskFormat.Mask_None:
|
|
|
|
|
for i, (m, n) in enumerate(zip(self.mask_index_q, self.mask_index_kv)):
|
|
|
|
|
q_mask[i, :, m:, :] = False
|
|
|
|
|
k_mask[i, :, n:, :] = False
|
|
|
|
|
mask[i, :, m:, :] = False
|
|
|
|
|
mask[i, :, :, n:] = False
|
|
|
|
|
return q_mask, k_mask, mask
|
|
|
|
|
|
|
|
|
def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False):
|
2024-06-12 20:04:25 +00:00
|
|
|
device = self.device
|
|
|
|
|
dtype = self.dtype
|
|
|
|
|
|
|
|
|
|
shape_dict = self.shape_dict()
|
|
|
|
|
|
|
|
|
|
if seed > 0:
|
|
|
|
|
torch.manual_seed(seed)
|
|
|
|
|
|
|
|
|
|
shape = (self.batch_size, self.sequence_length, self.num_heads, self.head_size)
|
|
|
|
|
q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
|
|
|
|
k = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
|
|
|
|
v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
|
|
|
|
|
|
|
|
bias_q = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
|
|
|
|
bias_k = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
|
|
|
|
bias_v = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1)
|
|
|
|
|
if no_bias_k_v:
|
|
|
|
|
bias_k = torch.zeros_like(bias_k)
|
|
|
|
|
bias_v = torch.zeros_like(bias_v)
|
|
|
|
|
|
2024-06-12 20:04:25 +00:00
|
|
|
k_bnsh = k.transpose(1, 2)
|
|
|
|
|
v_bnsh = v.transpose(1, 2)
|
|
|
|
|
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
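
The tflops column can be related to the latency column with a short calculation. The sketch below assumes the common FLOP count for non-causal attention with no past state, 4 * B * S * S * N * H (two matrix multiplications at 2 FLOPs per multiply-add), and assumes causal attention counts half of that. `attention_tflops` is a hypothetical helper, not necessarily the function used by the benchmark script.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size,
                     latency_seconds, causal=False):
    """Estimate TFLOPs per second for one attention call (assumed formula)."""
    # Q @ K^T and P @ V each cost 2 * B * N * S * S * H FLOPs.
    flops = 4 * batch_size * sequence_length * sequence_length * num_heads * head_size
    if causal:
        flops //= 2  # assumption: only the lower triangle is computed
    return flops / latency_seconds / 1e12


# First row of the table: B=4, S=2048, N=32, H=128, latency shown as 0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f} TFLOPs/s")
```

The estimate lands near 183 rather than the reported 179.5 because the latency in the table is rounded to four decimal places.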
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
        if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
            feeds = {
                "query": q.reshape(shape_dict["query"]),
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
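To make the layout difference concrete, here is a minimal sketch in plain Python of how the flat memory offset of one element differs between the two layouts. The helper names `offset_bsnh` and `offset_bnsh` are hypothetical and are not part of benchmark_mha.py; the sketch assumes contiguous row-major buffers.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Flat offset of the same element in a contiguous B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


# Small illustrative sizes (assumed, not taken from the benchmark).
S, N, H = 4, 2, 8

# Distance in elements between two consecutive tokens of the same head.
stride_bsnh = offset_bsnh(0, 1, 0, 0, S, N, H) - offset_bsnh(0, 0, 0, 0, S, N, H)
stride_bnsh = offset_bnsh(0, 1, 0, 0, S, N, H) - offset_bnsh(0, 0, 0, 0, S, N, H)

print("token stride in BSNH:", stride_bsnh)  # N * H elements
print("token stride in BNSH:", stride_bnsh)  # H elements
```

In BNSH the tokens of one head are adjacent in memory, while in BSNH they are separated by the other heads, which is one reason the torch and ort numbers for Q,K,V are not directly comparable.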
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
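The tflops column can be related to the latency column with a short calculation. The sketch below assumes the common convention of counting two matrix multiplications (Q·Kᵀ and P·V), each costing 2·B·N·S·S·H floating point operations for non-causal attention; the function names are hypothetical and the formula is an assumption, although it reproduces the rounded latencies in the table above.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size):
    """FLOPs for non-causal attention under the assumed 4*B*N*S*S*H convention."""
    return 4 * batch_size * num_heads * sequence_length * sequence_length * head_size


def latency_from_tflops(flops, tflops):
    """Latency in seconds implied by a throughput given in TFLOPs per second."""
    return flops / (tflops * 1e12)


# First GPU row: Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
flops_small = attention_flops(4, 2048, 32, 128)
latency_small = latency_from_tflops(flops_small, 179.5)

# Row: Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
flops_large = attention_flops(8, 4096, 32, 128)
latency_large = latency_from_tflops(flops_large, 190.8)

print(f"{latency_small:.4f} s, {latency_large:.4f} s")
```

Because the latency column is rounded to four decimals, recomputing tflops from the printed latency gives only an approximate match; going from tflops to latency, as above, is the more reliable consistency check.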
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
                "key": k.reshape(shape_dict["key"]),
                "value": v.reshape(shape_dict["value"]),
            }
        elif self.input_format == InputFormats.QKV_BSN3H:
            query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
            key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
            value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
            feeds = {
                "query": torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous(),
            }
        elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H:
            key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
            value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
            feeds = {
                "query": q.reshape(shape_dict["query"]),
                "key": torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous(),
            }
        else:
            assert self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH
            feeds = {
                "query": q.reshape(shape_dict["query"]),
                "key": k_bnsh.contiguous(),
                "value": v_bnsh.contiguous(),
            }
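For the packed branches above, `torch.dstack` concatenates the per-token `(tokens, num_heads, head_size)` views along the third axis, so each head ends up holding its Q, K and V slices back to back before the reshape to the packed 5D shape. The following is a minimal sketch of that packing in plain Python with nested lists instead of tensors; `dstack3` and `split_packed` are hypothetical helpers for illustration, not functions from the benchmark script.

```python
def dstack3(query, key, value):
    """Concatenate three [tokens][heads][head_size] nested lists along the last axis."""
    return [
        [q_head + k_head + v_head for q_head, k_head, v_head in zip(q_tok, k_tok, v_tok)]
        for q_tok, k_tok, v_tok in zip(query, key, value)
    ]


def split_packed(packed_head, head_size):
    """View one packed head of length 3 * head_size as [3][head_size]."""
    return [packed_head[i : i + head_size] for i in range(0, len(packed_head), head_size)]


def make(tag, tokens, heads, head_size):
    """Build a labeled [tokens][heads][head_size] nested list for inspection."""
    return [
        [[(tag, t, n, h) for h in range(head_size)] for n in range(heads)]
        for t in range(tokens)
    ]


tokens, heads, head_size = 2, 2, 3  # tokens stands for batch_size * sequence_length
query = make("q", tokens, heads, head_size)
key = make("k", tokens, heads, head_size)
value = make("v", tokens, heads, head_size)

packed = dstack3(query, key, value)            # shape [tokens][heads][3 * head_size]
q_part, k_part, v_part = split_packed(packed[1][0], head_size)
print(q_part == query[1][0], k_part == key[1][0], v_part == value[1][0])
```

The same reasoning applies to the packed KV branch, where only key and value are concatenated and each head holds two slices instead of three.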
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
only initialized once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it might be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the factory of
kernel loading has a static mutex to guard against multiple threads), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
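The once_flag pattern described above can be sketched in Python as follows. This is only an analogue for illustration: the actual fix is in C++, and the `OnceFlag` and `SequenceOffsetCache` classes, as well as the assumption that the cache stores cumulative offsets `i * sequence_length`, are hypothetical stand-ins rather than onnxruntime code.

```python
import threading


class OnceFlag:
    """Runs a callable at most once, even when called from several threads."""

    def __init__(self):
        self._lock = threading.Lock()
        self._done = False

    def call_once(self, fn):
        if self._done:          # fast path once initialization has finished
            return
        with self._lock:
            if not self._done:  # re-check under the lock
                fn()
                self._done = True


class SequenceOffsetCache:
    """Hypothetical cache of cumulative sequence offsets, read-only after init."""

    def __init__(self, max_batch_size, sequence_length):
        self._once = OnceFlag()
        self._max_batch_size = max_batch_size
        self._sequence_length = sequence_length
        self._offsets = None
        self.init_count = 0     # only used to show that init runs once

    def _initialize(self):
        self.init_count += 1
        # A tuple cannot be modified afterwards, mirroring the read-only cache.
        self._offsets = tuple(
            i * self._sequence_length for i in range(self._max_batch_size + 1)
        )

    def get(self):
        self._once.call_once(self._initialize)
        return self._offsets


cache = SequenceOffsetCache(max_batch_size=4, sequence_length=128)
threads = [threading.Thread(target=cache.get) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(cache.init_count, cache.get())
```

Sizing the cache for the maximum batch size up front is what lets it stay read-only: callers with smaller batches read a prefix instead of rewriting the buffer.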
(2) **Bias for cross attention**
That scenario assumes that only query has bias, but not key
and value. However, the assumption was not verified at runtime, there
was no comment documenting it, and there was no test case, so support
for the scenario was disabled by mistake. Actually, the scenario is used in
the whisper model (TODO: we shall add tests for whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever there is such an assumption.
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats. That means some cases might fail since there was no fallback. I
added new AddBiasTranspose cuda kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the related
code is scattered across different source files. I refactored the code to
move all related code to one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
New code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, flash or efficient attention was not used when the following
condition was true:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, and also use less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
We can see that we do not need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two
transposes in this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format, which is not always
necessary. I changed the code to convert k/v to BSNH or BNSH case by case,
so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I add a flag for
debug info in the AttentionData, so that we can output debug info during
the processing.
Also add functions to consolidate the dumping of inputs, QKV processing
and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from the cuda kernel.
#### Summary of changes
(1) Refactor CheckInputs, and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by flash
and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention for CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|---------
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer)
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00

        if self.has_past_input:
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
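As a rough cross-check of the tflops column, the sketch below assumes the conventional FLOP count for non-causal attention: two matrix multiplications (`Q @ K^T` and `P @ V`), each costing `2 * B * N * S * S * H` FLOPs. This formula is an assumption; the exact computation in the benchmark script is not shown here. The function name is hypothetical. Because the latencies in the table are rounded, the result only approximately matches the reported values.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s, causal=False):
    """Estimate TFLOPs per second for attention with no past state (assumed formula)."""
    # Q @ K^T and P @ V each take 2 * B * N * S * S * H floating point operations.
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    if causal:
        # Assumption: only about half of the score matrix is computed for causal attention.
        flops //= 2
    return flops / latency_s / 1e12


# First row of the table: Q,KV | 4 | 2048 | 32 | 128 | 0.0015 s (rounded latency).
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(round(estimate, 1))
```

The estimate lands within a few percent of the 179.5 reported for that row, which is consistent with the latency having been rounded to four decimal places.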
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
            feeds = {
                **feeds,
                "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
                "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(
                    mean=0, std=0.1
                ),
            }
2024-06-12 20:04:25 +00:00
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it could be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading
factory has a static mutex to guard against multiple threads), so the
mutable variable could be set by two different threads at the same time.
Add a once_flag to avoid that.
This also requires some workspace computation changes, so I did
not create a separate pull request.
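The actual fix is in C++ and uses a once_flag. As a language-neutral illustration of the same pattern (initialize once, then treat the cached content as read-only), here is a minimal Python analogue. The class `OnceCache` and the sample data are hypothetical and do not correspond to ONNX Runtime code.

```python
import threading


class OnceCache:
    """Build a value exactly once across threads, then expose it read-only."""

    def __init__(self, factory):
        self._factory = factory
        self._lock = threading.Lock()
        self._initialized = False
        self._value = None

    def get(self):
        if not self._initialized:  # fast path once initialization is done
            with self._lock:
                if not self._initialized:  # re-check while holding the lock
                    # Store an immutable tuple so later readers cannot modify it.
                    self._value = tuple(self._factory())
                    self._initialized = True
        return self._value


calls = []


def build_offsets():
    calls.append(1)
    return [0, 128, 256, 384]  # made-up cumulated sequence lengths


cache = OnceCache(build_offsets)
threads = [threading.Thread(target=cache.get) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls), cache.get())
```

Storing the result as a tuple mirrors the "read-only after initialization" half of the fix: even if several threads hold the cached value, none of them can resize or overwrite it.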
(2) **Bias for cross attention**
This scenario assumes that only the query has bias, and that the key
and value do not. However, the assumption was not verified at runtime,
it was not documented in comments, and there was no test case, so support
for the scenario was disabled by mistake. The scenario is actually used in
the whisper model (TODO: we shall add tests for whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever this assumption is made.
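The assumption can be checked offline along the following lines. This is a sketch only: it assumes the bias is a flat sequence laid out as `[bias_q | bias_k | bias_v]` (the same order the benchmark uses when concatenating the three biases), and the function name is hypothetical rather than an ONNX Runtime API.

```python
def kv_bias_is_zero(bias, hidden_size, v_hidden_size):
    """Return True if the key and value parts of a packed QKV bias are all zero.

    bias is assumed to be laid out as [bias_q | bias_k | bias_v], with
    hidden_size entries for q and k, and v_hidden_size entries for v.
    """
    expected = 2 * hidden_size + v_hidden_size
    if len(bias) != expected:
        raise ValueError(f"expected {expected} bias values, got {len(bias)}")
    # Everything after the query part belongs to key and value.
    return all(x == 0 for x in bias[hidden_size:])


query_only = [0.1, -0.2, 0.0, 0.0, 0.0, 0.0]
with_key_bias = [0.1, -0.2, 0.3, 0.0, 0.0, 0.0]
print(kv_bias_is_zero(query_only, 2, 2), kv_bias_is_zero(with_key_bias, 2, 2))
```

A check of this kind could run in a fusion script before the cross attention pattern is emitted, which is the verification that the TODO above asks for.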
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats, so some cases might fail since there was no fallback. I
added new AddBiasTranspose CUDA kernels for these formats to support
fallback, so that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the related
code was scattered across different source files. I refactored the code to
move all of it into one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, the following condition prevented flash or efficient
attention from running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
Add bias to query, key, value, and store in q, k, v workspace
else
Use query, key and value directly as q, k and v in kernel
```
As a result, we no longer need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two
transposes in this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were always converted to BSNH format, which is not
always necessary. I changed the code to convert k/v to BSNH or BNSH case by
case, so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I added a flag for
debug info in AttentionData, so that we can output debug info during
processing.
I also added functions to consolidate the dumping of inputs, QKV processing
and outputs, and added an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs, and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by
flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
        if self.has_bias:
            feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], dim=0).reshape(shape_dict["bias"]).contiguous()
2024-08-09 08:31:00 +00:00
        # Generate padding mask
        if self.mask_format != AttentionMaskFormat.Mask_None:
            self.mask_index_kv = torch.randint(
                1, self.total_sequence_length + 1, (self.batch_size,), dtype=torch.int32, device=self.device
            )
            if self.past_sequence_length > 0:
                self.mask_index_q = (
                    torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length
                )
            else:  # prompt case
                self.mask_index_q = self.mask_index_kv.clone()

        mask = None
        if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
            mask = self.mask_index_kv.clone()
        elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
            k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device)
            for i, n in enumerate(self.mask_index_kv):
                k_mask[i, :, n:, :] = False
            mask = k_mask.reshape(self.batch_size, self.total_sequence_length)
        else:
            assert self.mask_format == AttentionMaskFormat.Mask_None

        if mask is not None:
            feeds = {**feeds, "mask": mask.to(dtype=torch.int32)}  # mask is int32 (not bool) for MultiHeadAttention op.
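The 2D key padding mask built above marks the first `n` key positions of each batch entry as valid and the rest as padding. A dependency-free sketch of the same idea, using plain lists instead of torch tensors, is shown below; the function name is hypothetical.

```python
def key_padding_mask(key_lengths, total_sequence_length):
    """Return a B x T list of ints: 1 for valid key positions, 0 for padding."""
    mask = []
    for n in key_lengths:
        if not 1 <= n <= total_sequence_length:
            raise ValueError("each length must be in [1, total_sequence_length]")
        # The first n positions are real tokens; the remainder is padding.
        mask.append([1] * n + [0] * (total_sequence_length - n))
    return mask


mask = key_padding_mask([3, 1, 4], total_sequence_length=4)
for row in mask:
    print(row)
```

The values are integers rather than booleans, matching the code comment above that the mask input of the MultiHeadAttention op is int32.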
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
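The four accepted shapes amount to allowing the first dimension to be 1 or B and the second to be 1 or N, while the last two must be exactly S and T. A minimal sketch of that rule follows; the function name is hypothetical and this is not the operator's actual input check.

```python
def is_supported_attn_bias_shape(shape, batch_size, num_heads, sequence_length, total_sequence_length):
    """Check a bias shape against [1 or B, 1 or N, S, T]."""
    if len(shape) != 4:
        return False
    b, n, s, t = shape
    return (
        b in (1, batch_size)
        and n in (1, num_heads)
        and s == sequence_length
        and t == total_sequence_length
    )


B, N, S, T = 2, 8, 16, 32
for candidate in [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T), (B, 4, S, T)]:
    print(candidate, is_supported_attn_bias_shape(candidate, B, N, S, T))
```

Only the last candidate is rejected, because 4 is neither 1 nor the number of heads.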
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
        if self.has_attn_bias:
            attn_bias = torch.empty(
                (
                    1 if self.broadcast_attn_bias_dim_0 else self.batch_size,
                    1 if self.broadcast_attn_bias_dim_1 else self.num_heads,
                    self.sequence_length,
                    self.total_sequence_length,
                ),
                device=self.device,
                dtype=dtype,
            ).normal_(mean=0, std=0.1)
            feeds["attn_bias"] = attn_bias
2024-06-12 20:04:25 +00:00
        return feeds

    def get_input_output_names(self):
        if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
2024-08-09 08:31:00 +00:00
            inputs, outputs = ["query", "key", "value"], ["output"]
        elif self.input_format == InputFormats.QKV_BSN3H:
2024-06-12 20:04:25 +00:00
            inputs, outputs = ["query"], ["output"]
        elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H:
            inputs, outputs = ["query", "key"], ["output"]
        else:
            inputs, outputs = ["query", "key", "value"], ["output"]
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
        if self.has_bias:
2024-08-09 08:31:00 +00:00
            assert self.input_format != InputFormats.Q_KV_BSNH_BSN2H
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add an once_flag to CumulatedSequenceLengthCache to make sure it is
only initialized once; and change the cache to be read only after
initialization. Previously, the content is not read-only so it might be
changed by other thread and potentially cause buffer overrun.
* The kernel initialization is not guarded (Although the factory of
kernel loading has static mutex to guard multiple threading), so the
mutable variable might be set by two different threads at the same time.
Add an once_flag to avoid that.
This requires need some workspace computation change as well. So I did
not create a separated pull request.
(2) **Bias for cross attention**
That scenario has assumption that only query has bias, but not for key
and value. However, such assumption is not verified in runtime and there
was no comment of assumption, and there was no test case so the support
of scenario was disabled by mistake. Actually, the scenario is used in
whisper model (TODO: we shall add tests for whisper to CI pipeline, and
also update fusion script to verify such assumptions if needed.)
CUDA/CPU kernels supports bias for cross attention as long as bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever there is such assumption.
(3) **Fallback support**
Previously, unfused kernel did not support packed qkv and packed kv
formats. That means some case might fail since there is no fallback. I
added new AddBiasTranpose cuda kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could be easily out of sync since related
code are scattered in different source files. I refactor the code to
move all related code to one file (attention_prepare_qkv.cu) and add
asserts, so that the logic can be in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
New code does not use past_key/past_value for cross attention, so the
logic is more clear.
(6) **More coverage and less workspace and less transpose of flash and
efficient attention**
Previously, there is one condition does not run flash or efficient
attention:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for the
case, and also less workspace.
For example, cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
As a result, temp_k_workspace and temp_v_workspace no longer need to be
allocated, which uses less memory, and the new code saves two transposes
in this case.
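The two-branch logic above can be sketched in Python. This is only an illustration under stated assumptions: `prepare_qkv` is a hypothetical name, tensors are modelled as flat Python lists, and the returned flag stands in for "a q/k/v workspace was allocated". It is not the CUDA implementation.

```python
def prepare_qkv(query, key, value, bias=None):
    """Hypothetical sketch of the new PrepareQKV branching.

    bias, when given, is a (q_bias, k_bias, v_bias) tuple of lists with the
    same lengths as query, key and value (an assumption for illustration).
    Returns (q, k, v, used_workspace).
    """
    if bias is None:
        # No bias: use the inputs directly, no extra buffers are created.
        return query, key, value, False

    q_bias, k_bias, v_bias = bias
    # Bias present: write the biased values into new buffers (the "workspace").
    q = [x + b for x, b in zip(query, q_bias)]
    k = [x + b for x, b in zip(key, k_bias)]
    v = [x + b for x, b in zip(value, v_bias)]
    return q, k, v, True
```

In the no-bias branch the function hands back the very same objects it received, which mirrors the point of the change: no temporary k/v workspace and no extra transposes when nothing has to be added.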
Flash and efficient attention support both BSNH and BNSH formats for k
and v. The old code converted k/v to BSNH format even when that was not
necessary. I changed the code to convert k/v to BSNH or BNSH case by
case, so that more cases can be covered by flash or efficient attention
to improve performance.
(7) **Debugging support**
Previously, little debug information was available. This change adds a
debug-info flag to AttentionData so that debug information can be output
during processing.
It also adds functions to consolidate the dumping of inputs, QKV
processing and outputs, and adds an environment variable
`ORT_ENABLE_GPU_DUMP` that allows dumping from CUDA kernels to be
disabled.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or packed kv
inputs.
(3) Change a few cases in PrepareQKV to allow more cases to be covered
by flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
inputs = [*inputs, "bias"]
2024-08-09 08:31:00 +00:00
        if self.mask_format != AttentionMaskFormat.Mask_None:
            inputs = [*inputs, "mask"]
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow
[1, N, S, T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for the CUDA
and CPU EPs.
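To make the four accepted shapes concrete, here is a minimal pure-Python sketch of broadcasting an attention bias over the first two dimensions of a [B, N, S, T] score tensor. The function name `add_attention_bias` and the nested-list representation are assumptions made for illustration; the real kernels operate on GPU/CPU buffers.

```python
def add_attention_bias(scores, attn_bias):
    """Add attn_bias of shape [B or 1, N or 1, S, T] to scores of shape [B, N, S, T].

    Both arguments are nested Python lists (hypothetical stand-ins for tensors).
    """
    batch, heads = len(scores), len(scores[0])
    bias_batch, bias_heads = len(attn_bias), len(attn_bias[0])
    if bias_batch not in (1, batch) or bias_heads not in (1, heads):
        raise ValueError("attention bias dims 0 and 1 must be 1 or match scores")

    result = []
    for b in range(batch):
        # Broadcast over batch: reuse entry 0 when the bias batch dim is 1.
        bias_b = attn_bias[b if bias_batch > 1 else 0]
        per_head = []
        for n in range(heads):
            # Broadcast over heads: reuse entry 0 when the bias head dim is 1.
            bias_bn = bias_b[n if bias_heads > 1 else 0]
            per_head.append(
                [[s + a for s, a in zip(score_row, bias_row)]
                 for score_row, bias_row in zip(scores[b][n], bias_bn)]
            )
        result.append(per_head)
    return result
```

A [1, 1, S, T] bias is therefore shared by every batch entry and head, while a [B, 1, S, T] bias differs per batch entry but is shared across heads.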
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
        if self.has_attn_bias:
            inputs = [*inputs, "attn_bias"]
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
        if self.has_past_input:
            inputs = [*inputs, "past_key", "past_value"]

        if self.has_present_output:
            outputs = [*outputs, "present_key", "present_value"]

        return inputs, outputs
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
2024-06-19 17:23:26 +00:00
def fill_optional_mha_inputs(input_names):
Extend Attention Bias Broadcast Support (#21710)
2024-08-16 22:40:04 +00:00
    inputs = ["query", "key", "value", "bias", "mask", "attn_bias", "past_key", "past_value"]
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
    # Replace optional inputs that are not in input_names with an empty string.
    inputs_with_optional = [input if input in input_names else "" for input in inputs]

    # Remove empty strings at the end of the list.
    while inputs_with_optional[-1] == "":
        inputs_with_optional.pop(-1)

    return inputs_with_optional
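A small standalone sketch of how this helper behaves may be useful. The version below restates the function so that it runs on its own; the name `fill_optional_inputs_sketch`, the renamed loop variable and the extra emptiness guard in the `while` condition are assumptions added here for illustration and are not part of the original script.

```python
# Fixed input order, as listed in fill_optional_mha_inputs above.
ALL_INPUTS = ["query", "key", "value", "bias", "mask", "attn_bias", "past_key", "past_value"]


def fill_optional_inputs_sketch(input_names):
    """Keep positional order, blank out unused inputs, trim trailing blanks."""
    # Unused optional inputs become "" so later inputs keep their position.
    filled = [name if name in input_names else "" for name in ALL_INPUTS]

    # Trailing empty strings carry no positional information, so drop them.
    # The "filled and" guard (an addition) avoids an IndexError on an empty list.
    while filled and filled[-1] == "":
        filled.pop()

    return filled


example = fill_optional_inputs_sketch(["query", "key", "value", "attn_bias"])
# example == ["query", "key", "value", "", "", "attn_bias"]
```

Blank placeholders are kept in the middle of the list because ONNX node inputs are positional: `attn_bias` must stay in the sixth slot even when `bias` and `mask` are absent.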
2024-06-19 17:23:26 +00:00
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)
### Description
- [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe.
- [x] Add multi-threading tests
Previously, the kernel parameters (params) were stored as a member of
the MHA runner, which means that different threads could change the
params at the same time and affect each other.
For example, if batch_size and seq_len were changed by another thread to
larger values in setup(...), a buffer overrun could happen in run(...)
because a kernel could read/write memory outside the allocated buffers.
In the new implementation, I changed the API and removed mutable member
variables to make it thread-safe. Below is a summary of the change:
Before:
```
class FusedMHARunnerFP16v2::mhaImpl {
void setup(int seq_len, int batch_size) {
// change scalar params
}
void run(input, output) {
// change params for input and output pointers
// launch kernel using params
}
Fused_multihead_attention_params_v2 params; // mutable, not thread-safe
}
```
After:
```
class FusedMHARunnerFP16v2::FmhaImpl {
void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
// change params
}
void run(params, input, output) {
// change params with input and output pointers
// launch kernel using params
}
}
```
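The same pattern can be illustrated outside C++. The Python sketch below is an analogy only, under stated assumptions: `KernelParams`, `StatelessRunner` and `run_concurrently` are hypothetical names, and the "kernel" is just a sum over a list. It shows the idea of the "After" version, where each caller owns its params and the shared runner holds no mutable state.

```python
import threading
from dataclasses import dataclass


@dataclass
class KernelParams:
    """Hypothetical per-call parameters, owned by the calling thread."""
    seq_len: int = 0
    batch_size: int = 0


class StatelessRunner:
    """Shared runner with no mutable members, mirroring the 'After' design."""

    def setup(self, seq_len, batch_size, params):
        # Write into the caller's params instead of a member variable.
        params.seq_len = seq_len
        params.batch_size = batch_size

    def run(self, params, data):
        # The size check stands in for "the kernel stays inside its buffers".
        expected = params.batch_size * params.seq_len
        if len(data) != expected:
            raise ValueError("buffer size does not match params")
        return sum(data)


def run_concurrently(shapes):
    """Run one thread per (seq_len, batch_size) pair against one shared runner."""
    runner = StatelessRunner()
    results = [None] * len(shapes)

    def worker(index, seq_len, batch_size):
        params = KernelParams()  # local to this thread
        runner.setup(seq_len, batch_size, params)
        results[index] = runner.run(params, [1] * (seq_len * batch_size))

    threads = [
        threading.Thread(target=worker, args=(i, seq_len, batch_size))
        for i, (seq_len, batch_size) in enumerate(shapes)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results
```

Because no thread can observe another thread's `seq_len` or `batch_size`, the size check in `run` cannot be invalidated by a concurrent `setup` call, which is the failure mode described for the old design.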
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
https://github.com/microsoft/onnxruntime/issues/21413
2024-07-22 17:41:08 +00:00
def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False):
2024-06-12 20:04:25 +00:00
    input_names, output_names = config.get_input_output_names()
2024-06-19 17:23:26 +00:00
2024-06-12 20:04:25 +00:00
    float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We did some change to remove dependency on Torch, then removed backward
and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works in Linux. For Windows, we will
enable later with Cutlass 3.2.
Two environment variables can be used for testing purpose:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
Note that flash attention cannot use packed QKV format, so extra
Transpose is needed. We found that TensorRT kernel is faster for
sequence length <= 512 for packed QKV. The reason might be no transpose
is needed for TensorRT kernel in this format.
We also notice that, TensorRT kernel is faster for stable diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim |
flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention
(TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
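The TFLOPs/s columns above can be related to kernel latency with a rough FLOP count. The sketch below assumes the usual estimate for non-causal attention: two matrix multiplications (Q·Kᵀ and P·V), each costing 2·B·N·S·L·H floating-point operations. Whether `benchmark_mha.sh` uses exactly this formula is an assumption, and the function names are hypothetical.

```python
def attention_flops(batch_size, seq_len, num_heads, head_dim, kv_seq_len=None):
    """Approximate FLOPs of one non-causal attention forward pass.

    Assumption: only the two matmuls (Q @ K^T and P @ V) are counted,
    each costing 2 * B * N * S * L * H floating-point operations.
    """
    if kv_seq_len is None:
        kv_seq_len = seq_len
    return 4 * batch_size * num_heads * seq_len * kv_seq_len * head_dim


def tflops_per_second(flops, latency_seconds):
    """Convert a FLOP count and a measured latency into TFLOPs per second."""
    return flops / latency_seconds / 1e12


def latency_ms_from_tflops(flops, tflops):
    """Invert the metric: estimated latency in milliseconds for a table entry."""
    return flops / (tflops * 1e12) * 1e3


# Example: the QKV row with batch 1, sequence length 16384, 8 heads, head dim 160.
flops = attention_flops(batch_size=1, seq_len=16384, num_heads=8, head_dim=160)
estimated_ms = latency_ms_from_tflops(flops, 177.6)  # flash_v2 column
```

Under this estimate, a higher TFLOPs/s value for the same row means a proportionally lower latency, so the columns in one row can be compared directly.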
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention CUDA kernels. A Linux
build with CUDA might fail when the machine has limited memory and a large number
of CPUs. The workaround is to use a build machine with more
memory, or to pass an argument like `--nvcc_threads 1` to limit nvcc threads in
the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
|
|
|
nodes = [
|
|
|
|
|
helper.make_node(
|
|
|
|
|
"MultiHeadAttention",
|
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it might be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the factory for
kernel loading has a static mutex to guard against multiple threads), so the
mutable variable might be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
(2) **Bias for cross attention**
This scenario assumes that only the query has bias, and that key
and value do not. However, the assumption was not verified at runtime, there
was no comment documenting it, and there was no test case, so support
for the scenario was disabled by mistake. The scenario is in fact used in the
Whisper model (TODO: we shall add tests for Whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed).
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever the assumption is made.
(3) **Fallback support**
Previously, the unfused kernel did not support packed QKV and packed KV
formats, so some cases might fail since there was no fallback. I
added new AddBiasTranspose CUDA kernels for them to support fallback, so
that all supported cases will not fail.
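The once-only initialization described in item (1) can be illustrated with a small Python analogue. This is not the C++ implementation: it swaps `std::once_flag` for a lock with a double check, and the class name, fields and the idea that the cache holds cumulative offsets `0, S, 2S, ...` are assumptions made for illustration.

```python
import threading


class CumulatedSeqLenCacheSketch:
    """Hypothetical analogue of a cache that is built once and then only read."""

    def __init__(self, max_batch_size, sequence_length):
        self._max_batch_size = max_batch_size
        self._sequence_length = sequence_length
        self._lock = threading.Lock()
        self._offsets = None
        self.init_count = 0  # only here to show that initialization runs once

    def get(self):
        # Fast path: already initialized, nothing is mutated any more.
        if self._offsets is None:
            with self._lock:
                # Second check: another thread may have initialized it while
                # this thread was waiting for the lock.
                if self._offsets is None:
                    self.init_count += 1
                    # A tuple is immutable, so readers cannot change the content.
                    self._offsets = tuple(
                        i * self._sequence_length
                        for i in range(self._max_batch_size + 1)
                    )
        return self._offsets


cache = CumulatedSeqLenCacheSketch(max_batch_size=3, sequence_length=4)
threads = [threading.Thread(target=cache.get) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Because the stored value is immutable after the first call, a concurrent caller can neither enlarge nor overwrite it, which is the property the fix relies on to avoid buffer overruns.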
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the related
code was scattered across different source files. I refactored the code to
move all of it into one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
New code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, one condition prevented flash or efficient
attention from running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, flash and efficient attention can be used for this
case, with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
We do not need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two transposes in
this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format, which is not always
necessary. I changed the code to convert k/v to BSNH or BNSH case by case,
so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I add a debug info flag
to AttentionData so that we can output debug info during
processing.
Also add functions to consolidate the dumping of inputs, QKV processing
and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs, and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by flash
and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not support in CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
|
|
|
fill_optional_mha_inputs(input_names),
|
2024-06-12 20:04:25 +00:00
|
|
|
output_names,
|
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
|
|
|
"MultiHeadAttention_0",
|
|
|
|
|
num_heads=config.num_heads,
|
2024-06-19 17:23:26 +00:00
|
|
|
unidirectional=int(config.causal),
|
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
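The broadcast rule for attention bias described above can be sketched in a few lines of Python. This is an illustration under assumptions, not the kernel code: the function names are hypothetical, B, N, S and T are taken to mean batch size, number of heads, sequence length and total sequence length, and the tensor is assumed to be stored row-major.

```python
def is_supported_attention_bias_shape(shape, batch_size, num_heads, seq_len, total_seq_len):
    """Check a bias shape against [1|B, 1|N, S, T]."""
    if len(shape) != 4:
        return False
    b, n, s, t = shape
    return (
        b in (1, batch_size)
        and n in (1, num_heads)
        and s == seq_len
        and t == total_seq_len
    )


def broadcast_bias_offset(shape, b, n, s, t):
    """Flat offset into a row-major bias tensor for logical position (b, n, s, t).

    A dimension of size 1 is broadcast, so its index is always 0.
    """
    dim_b, dim_n, dim_s, dim_t = shape
    bi = 0 if dim_b == 1 else b
    ni = 0 if dim_n == 1 else n
    return ((bi * dim_n + ni) * dim_s + s) * dim_t + t


# All four shapes listed in the description pass the check for B=2, N=4, S=3, T=5.
accepted = [
    is_supported_attention_bias_shape(shape, 2, 4, 3, 5)
    for shape in [(1, 4, 3, 5), (2, 4, 3, 5), (2, 1, 3, 5), (1, 1, 3, 5)]
]
```

With a `[1, 1, S, T]` bias every batch and head reads the same S x T block, so the offset depends only on `s` and `t`.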
2024-08-16 22:40:04 +00:00
|
|
|
scale=config.scale,
|
2024-08-09 08:31:00 +00:00
|
|
|
mask_filter_value=float("-inf"),
|
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
|
|
|
domain="com.microsoft",
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
|
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)
### Description
- [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe.
- [x] Add multi-threading tests
Previously, the kernel parameters (params) were stored as a member of the mha
runner, which means that different threads might change the params at
the same time and affect other threads.
For example, if batch_size and seq_len were changed by another thread to
larger values in setup(...), a buffer overrun might happen in run(...)
because a kernel could read/write memory outside the allocated
buffers.
In the new implementation, I changed the API and removed mutable member
variables to make it thread-safe. Below is a summary of the change:
Before:
```
class FusedMHARunnerFP16v2::mhaImpl {
  void setup(int seq_len, int batch_size) {
    // change scalar params
  }
  void run(input, output) {
    // change params for input and output pointers
    // launch kernel using params
  }
  Fused_multihead_attention_params_v2 params; // mutable, not thread-safe
}
```
After:
```
class FusedMHARunnerFP16v2::FmhaImpl {
  void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
    // change params
  }
  void run(params, input, output) {
    // change params with input and output pointers
    // launch kernel using params
  }
}
```
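The same idea can be shown as a runnable Python analogue: the runner keeps no per-call state, and each caller owns the parameter object it passes to `setup` and `run`. This is a sketch of the pattern only, not the C++ code; `Params`, `Runner` and the list-based "kernel" are hypothetical stand-ins.

```python
import threading
from dataclasses import dataclass


@dataclass
class Params:
    """Per-call parameters, owned by the calling thread."""
    seq_len: int = 0
    batch_size: int = 0


class Runner:
    """Shared runner with no mutable member variables."""

    def setup(self, seq_len, batch_size, params):
        # Only the caller's params object is modified.
        params.seq_len = seq_len
        params.batch_size = batch_size

    def run(self, params, inputs):
        # Stand-in for a kernel launch: the buffer size must match params.
        expected = params.seq_len * params.batch_size
        if len(inputs) != expected:
            raise ValueError("buffer size does not match params")
        return sum(inputs)


def worker(runner, seq_len, batch_size, results, index):
    params = Params()  # local to this thread, never shared
    runner.setup(seq_len, batch_size, params)
    results[index] = runner.run(params, [1] * (seq_len * batch_size))


def run_concurrently(shapes):
    runner = Runner()
    results = [None] * len(shapes)
    threads = [
        threading.Thread(target=worker, args=(runner, s, b, results, i))
        for i, (s, b) in enumerate(shapes)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

Since `Runner` has no fields that change after construction, one thread calling `setup` with a larger shape cannot affect the buffer size another thread sees in `run`.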
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
https://github.com/microsoft/onnxruntime/issues/21413
2024-07-22 17:41:08 +00:00
|
|
|
shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict()
|
2024-06-12 20:04:25 +00:00
|
|
|
inputs = [
|
2024-08-09 08:31:00 +00:00
|
|
|
helper.make_tensor_value_info(
|
|
|
|
|
input_name, TensorProto.INT32 if input_name == "mask" else float_type, list(shape_dict[input_name])
|
|
|
|
|
)
|
2024-06-12 20:04:25 +00:00
|
|
|
for input_name in input_names
|
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
|
|
|
if input_name
|
2024-06-12 20:04:25 +00:00
|
|
|
]
|
2024-06-19 17:23:26 +00:00
|
|
|
|
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, then removed backward
and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works in Linux. For Windows, we will
enable later with Cutlass 3.2.
Two environment variables can be used for testing purpose:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
Note that flash attention cannot use packed QKV format, so extra
Transpose is needed. We found that TensorRT kernel is faster for
sequence length <= 512 for packed QKV. The reason might be no transpose
is needed for TensorRT kernel in this format.
We also notice that, TensorRT kernel is faster for stable diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim |
flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention
(TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses huge memory while compiling flash attention CUDA kernel. Linux
build with CUDA might fail when machine has limited memory while number
of CPUs is large. Walkaround is to use a build machine with larger
memory, or use argument like `--nvcc_threads 1` to limit nvcc threads in
build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
    outputs = [
2024-06-12 20:04:25 +00:00
        helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name]))
        for output_name in output_names
        if output_name
    ]

    graph = helper.make_graph(
        nodes,
        "MultiHeadAttention_Graph",
        inputs,
        outputs,
    )

    model = helper.make_model(graph)
2024-06-19 17:23:26 +00:00
    return model.SerializeToString()
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
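As a rough illustration of how the CSV output might be consumed, the sketch below picks the kernel with the highest TFLOPS for each configuration. The column names mirror the headers of the result tables in this description. Whether the CSV written by benchmark_mha.py uses exactly these names is an assumption, and `fastest_kernels` is a hypothetical helper.

```python
import csv
import io

# Sample rows copied from the GPU results table in this description.
# The column names are assumed to match the table headers.
SAMPLE_CSV = """format,batch_size,sequence_length,num_heads,head_size,latency,tflops,kernel
"Q,KV",4,2048,32,128,0.0015,179.5,ort:flash
"Q,K,V",4,2048,32,128,0.0017,159.4,torch:default
"Q,K,V",4,2048,32,128,0.0054,51.3,torch:math
"Q,KV",8,4096,32,128,0.0225,97.7,ort:efficient
"Q,K,V",8,4096,32,128,0.0139,158.7,torch:default
"""


def fastest_kernels(csv_file):
    """Return {(batch_size, sequence_length): (kernel, tflops)} with the highest tflops."""
    best = {}
    for row in csv.DictReader(csv_file):
        key = (int(row["batch_size"]), int(row["sequence_length"]))
        tflops = float(row["tflops"])
        if key not in best or tflops > best[key][1]:
            best[key] = (row["kernel"], tflops)
    return best


result = fastest_kernels(io.StringIO(SAMPLE_CSV))
for key, (kernel, tflops) in sorted(result.items()):
    print(key, kernel, tflops)
```

The format values contain commas, so they are quoted in the sample, which `csv.DictReader` handles without extra configuration.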
For the Q,K,V format, torch uses the BNSH layout while ORT uses BSNH, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
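To make the layout difference concrete, here is a minimal pure-Python sketch of how the same logical element (b, s, n, h) maps to different flat offsets in BSNH and BNSH storage. The function names and the tiny sizes are hypothetical, and this is not how either framework implements the transpose.

```python
# Flat-buffer offsets for the same logical element (b, s, n, h) in two layouts.

def offset_bsnh(b, s, n, h, S, N, H):
    """Offset when the tensor is stored as [batch, sequence, heads, head_size]."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Offset when the tensor is stored as [batch, heads, sequence, head_size]."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(data, B, S, N, H):
    """Copy a flat BSNH buffer into a new flat BNSH buffer (swap axes 1 and 2)."""
    out = [None] * len(data)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = data[src]
    return out


B, S, N, H = 1, 2, 3, 2  # small hypothetical sizes
bsnh = list(range(B * S * N * H))
bnsh = bsnh_to_bnsh(bsnh, B, S, N, H)
print(bsnh)
print(bnsh)
```

In BSNH the heads of one token are adjacent in memory, while in BNSH all tokens of one head are adjacent, which is one reason latency can differ between the two even when the arithmetic is identical.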
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
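The `tflops` column can be related to the latency column with a simple cost model. The sketch below assumes the common estimate of two matrix multiplications (Q·Kᵀ and P·V), each costing 2·B·N·S·S·H floating point operations, with no past state. This formula is an assumption rather than something stated in this description, and `attention_tflops` is a hypothetical helper. It reproduces the reported values only approximately because the latencies in the table are rounded to four decimal places.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size,
                     latency_s, causal=False):
    """Estimate TFLOPS for self-attention without past state.

    Assumed cost model: two matrix multiplications (Q*K^T and P*V), each costing
    2 * B * N * S * S * H floating point operations; causal masking halves the work.
    """
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    if causal:
        flops //= 2
    return flops / latency_s / 1e12


# Row from the table: Q,KV | 8 | 4096 | 32 | 128 | 0.0115 s | 190.8 TFLOPS | ort:flash
estimate = attention_tflops(8, 4096, 32, 128, 0.0115)
print(f"estimated {estimate:.1f} TFLOPS vs reported 190.8")
```

Longer-running configurations give a closer match, since rounding the latency matters less there.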
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
def create_ort_session(
2024-06-12 20:04:25 +00:00
    config: MultiHeadAttentionConfig,
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
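
To make the layout difference concrete, here is a minimal sketch in plain Python of where one element lives in a contiguous buffer under each layout. B, S, N and H stand for batch, sequence length, number of heads and head size. The helper names `bsnh_offset` and `bnsh_offset` are hypothetical and are not part of benchmark_mha.py.

```python
def bsnh_offset(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous BSNH buffer."""
    return ((b * S + s) * N + n) * H + h


def bnsh_offset(b, s, n, h, S, N, H):
    """Flat offset of the same element in a contiguous BNSH buffer."""
    return ((b * N + n) * S + s) * H + h


# Toy sizes, chosen only for illustration.
S, N, H = 4, 2, 3

# Distance in memory between token 0 and token 1 of the same head.
bsnh_stride = bsnh_offset(0, 1, 0, 0, S, N, H) - bsnh_offset(0, 0, 0, 0, S, N, H)
bnsh_stride = bnsh_offset(0, 1, 0, 0, S, N, H) - bnsh_offset(0, 0, 0, 0, S, N, H)
print(bsnh_stride, bnsh_stride)
```

In BSNH, consecutive tokens of one head are N * H elements apart, while in BNSH they are H elements apart, so the two frameworks read memory with different access patterns even for identical shapes.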
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
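
The tflops column above is consistent with counting 4 * B * S^2 * N * H floating-point operations per non-causal attention call (two matrix multiplications of 2 * S * S * H operations each per head) and dividing by latency. That formula is an assumption inferred from the numbers in the table, not taken from the script, and `attention_tflops` is a hypothetical helper:

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimated TFLOPS of one non-causal attention call (assumed FLOP count)."""
    # Q*K^T and P*V each cost 2 * S * S * H operations per head.
    flops = 4 * batch_size * sequence_length**2 * num_heads * head_size
    return flops / latency_s / 1e12


# Row "Q,KV | 8 | 4096 | 32 | 128 | 0.0115" from the table above.
estimate = attention_tflops(8, 4096, 32, 128, 0.0115)
print(round(estimate, 1))
```

The estimate lands close to the reported value rather than exactly on it, because the latency in the table is rounded to four decimal places.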
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
    session_options=None,
    attention_kernel=SdpaKernel.DEFAULT,
    use_symbolic_shape: bool = True,
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it could
be changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel loading
factory has a static mutex to guard against concurrent access), so the
mutable variable could be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
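
The "initialize once, then read-only" idea can be sketched in Python as follows. This is only an analogy of the once_flag pattern described above, not the C++ code from the change; `OnceCache` and `build_cumulative_lengths` are hypothetical names.

```python
import threading


class OnceCache:
    """Builds a value at most once; afterwards it is only read."""

    def __init__(self, factory):
        self._factory = factory
        self._lock = threading.Lock()
        self._ready = False
        self._value = None

    def get(self):
        # Fast path: after initialization the value is never modified.
        if not self._ready:
            with self._lock:
                # Re-check under the lock so only one thread runs the factory.
                if not self._ready:
                    self._value = tuple(self._factory())  # immutable snapshot
                    self._ready = True
        return self._value


calls = []


def build_cumulative_lengths():
    calls.append(1)
    # Cumulative sequence lengths for 4 sequences of length 128.
    return [i * 128 for i in range(5)]


cache = OnceCache(build_cumulative_lengths)
threads = [threading.Thread(target=cache.get) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls), cache.get())
```

Storing the result as a tuple mirrors the read-only requirement: once the cache is built, no caller can change its content.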
(2) **Bias for cross attention**
This scenario assumes that only query has bias, and that key and value
do not. However, the assumption was not verified at runtime, was not
documented in comments, and had no test case, so support for the
scenario was disabled by mistake. The scenario is actually used in the
whisper model (TODO: we shall add tests for whisper to CI pipeline, and
also update fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever this assumption is made.
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats, so some cases could fail because there was no fallback. I
added new AddBiasTranspose cuda kernels for them to support fallback, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the
related code was scattered across different source files. I refactored
the code to move all related code to one file (attention_prepare_qkv.cu)
and added asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
New code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, flash or efficient attention did not run when the following
condition was true:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
  Add bias to query, key, value, and store in q, k, v workspace
else
  Use query, key and value directly as q, k and v in kernel
```
As a result, we do not need to allocate temp_k_workspace and
temp_v_workspace, so less memory is used. The new code also saves two
transposes in this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format even when that was
not necessary. I changed the code to convert k/v to BSNH or BNSH case by
case, so that more cases can be covered by flash or efficient attention
to improve performance.
(7) **Debugging support**
Previously, little debug info was available. In this change, I add a
flag for debug info in the AttentionData, so that we can output debug
info during the processing.
Also add functions to consolidate the dumping of inputs, QKV processing
and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to
allow disabling dumping from the cuda kernel.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered
by flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention for CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported in CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported in CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
    use_tf32: bool = True,
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
) -> CudaSession:
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
2024-07-27 01:45:14 +00:00
    if config.verbose:
        print(f"create session for {vars(config)}")
    onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=use_symbolic_shape)
2024-06-12 20:04:25 +00:00
2024-06-19 17:23:26 +00:00
    if config.provider == "CUDAExecutionProvider":
        device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
        provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses the BSNH layout, so
the comparison is not apples-to-apples. However, a large latency difference
could still be a warning sign.
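To make the layout difference concrete, here is a minimal pure-Python sketch of how the same logical element lands at different flat offsets in BSNH and BNSH buffers. The helper names are hypothetical and are not part of benchmark_mha.py; B, S, N and H stand for batch size, sequence length, number of heads and head size.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Flat offset of the same logical element in a contiguous B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Copy a flat BSNH buffer into a new flat BNSH buffer (illustrative, not optimized)."""
    out = [None] * (B * S * N * H)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out
```

A kernel that reads one layout directly and a kernel that needs a transpose first do different amounts of memory work, which is one reason the torch and ort numbers are not strictly comparable.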
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
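
The tflops column can be related to the latency column with a simple FLOP count. The sketch below assumes the commonly used non-causal estimate of two matrix multiplications (Q·Kᵀ and P·V), each costing 2·S·S·H FLOPs per head; this formula is an assumption and is not taken from benchmark_mha.py, and the function name is hypothetical.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPS for non-causal attention from a measured latency in seconds.

    Assumed cost model: Q*K^T and P*V are each 2*S*S*H FLOPs per head,
    so the total is 4 * B * N * S * S * H.
    """
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_s / 1e12


# First row of the table: batch 4, sequence length 2048, 32 heads, head size 128.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f} TFLOPS")
```

For the first row this gives roughly 183 TFLOPS, against 179.5 in the table. The gap is plausibly explained by the latency being rounded to four decimal places, but that is a guess rather than something the source states.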
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
        provider_options["sdpa_kernel"] = int(attention_kernel)
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it could be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading
factory has a static mutex to guard against concurrent access), so the
mutable variable could be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
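The initialize-once, then read-only pattern can be sketched in Python as follows. This is only an analogy: the actual fix is in C++ and uses a once_flag, while the class below is hypothetical, uses a lock instead, and assumes a fixed sequence length for every batch entry.

```python
import threading


class OnceInitializedCache:
    """Hypothetical stand-in for a cache of cumulated sequence lengths."""

    def __init__(self, max_batch_size, sequence_length):
        self._max_batch_size = max_batch_size
        self._sequence_length = sequence_length
        self._lock = threading.Lock()
        self._data = None
        self.init_count = 0  # only here to show that initialization runs once

    def get(self):
        # The lock guarantees that exactly one thread builds the data.
        with self._lock:
            if self._data is None:
                self.init_count += 1
                # An immutable tuple keeps the content read-only afterwards.
                self._data = tuple(
                    i * self._sequence_length for i in range(self._max_batch_size + 1)
                )
        return self._data


cache = OnceInitializedCache(max_batch_size=3, sequence_length=4)
workers = [threading.Thread(target=cache.get) for _ in range(8)]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
```

Because the cached value is an immutable tuple, a later caller cannot overwrite entries that another thread is still reading, which is the property the fix aims for.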
(2) **Bias for cross attention**
This scenario assumes that only the query has bias, and the key
and value do not. However, the assumption was not verified at runtime, it
was not documented in comments, and there was no test case, so support
for the scenario was disabled by mistake. In fact, the scenario is used in
the whisper model (TODO: we shall add tests for whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed.)
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever the assumption is made.
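A check for this assumption could look like the sketch below. It assumes the bias is a flat sequence that concatenates the Q, K and V biases with lengths hidden_size, hidden_size and v_hidden_size; both the layout and the function name are assumptions for illustration, not the code added in this change.

```python
def only_query_has_bias(bias, hidden_size, v_hidden_size):
    """Return True when the K and V parts of a concatenated QKV bias are all zero.

    Assumed layout: [q_bias (hidden_size), k_bias (hidden_size), v_bias (v_hidden_size)].
    """
    expected = 2 * hidden_size + v_hidden_size
    if len(bias) != expected:
        raise ValueError(f"bias has {len(bias)} elements, expected {expected}")
    # Everything after the query part must be zero for the cross attention case.
    return all(value == 0 for value in bias[hidden_size:])
```

A fusion script or a test could run such a check before relying on the cross attention path, which is the kind of verification the TODO above asks for.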
(3) **Fallback support**
Previously, the unfused kernel did not support the packed qkv and packed kv
formats, so some cases might fail because there was no fallback. I
added new AddBiasTranspose CUDA kernels to support fallback for them, so
that all supported cases will not fail.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync since the related
code was scattered across different source files. I refactored the code to
move all of it into one file (attention_prepare_qkv.cu) and added
asserts, so that the logic stays in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage, less workspace and fewer transposes for flash and
efficient attention**
Previously, one condition prevented flash or efficient attention
from running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this
case, with less workspace.
For example, for cross attention with bias, the original code uses two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
As a result, we do not need to allocate temp_k_workspace and
temp_v_workspace, which uses less memory. The new code also saves two transposes in
this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format, which is not always
necessary. I changed the code to convert k/v to BSNH or BNSH case by case,
so that more cases can be covered by flash or efficient
attention to improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I added a flag for
debug info in AttentionData, so that we can output debug info during
processing.
Also added functions to consolidate the dumping of inputs, QKV processing
and outputs, and added an environment variable `ORT_ENABLE_GPU_DUMP` to allow
disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by flash
and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
        provider_options["use_tf32"] = int(use_tf32)
2024-06-19 17:23:26 +00:00
        providers = [(config.provider, provider_options), "CPUExecutionProvider"]
2024-06-12 20:04:25 +00:00
    else:
        providers = ["CPUExecutionProvider"]
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
2024-07-27 01:45:14 +00:00
    ort_session = InferenceSession(onnx_model_str, session_options, providers=providers)
    return ort_session


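The provider selection in the function above can be isolated into a small helper, sketched here without importing onnxruntime so that it runs anywhere. The helper name and the `device_id`-only options dictionary are assumptions; the real code obtains its options from `CudaSession.get_cuda_provider_options`, whose exact contents are not shown in this file view.

```python
def build_providers(provider, device_id=0, sdpa_kernel=0, use_tf32=True):
    """Build a providers list in the shape passed to InferenceSession(providers=...)."""
    if provider == "CUDAExecutionProvider":
        # Hypothetical minimal options; the benchmark gets these from CudaSession.
        provider_options = {"device_id": device_id}
        # Select the attention kernel through a provider option rather than
        # an environment variable, as the benchmark does.
        provider_options["sdpa_kernel"] = int(sdpa_kernel)
        provider_options["use_tf32"] = int(use_tf32)
        # The CPU provider is listed last as a fallback.
        return [(provider, provider_options), "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]


print(build_providers("CUDAExecutionProvider", sdpa_kernel=1))
print(build_providers("CPUExecutionProvider"))
```

Keeping this logic in one place makes it easy to compare kernels by changing only the `sdpa_kernel` value between runs.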
def create_session(
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
    config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_tf32: bool = True
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
) -> CudaSession:
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
### Description
#### Issues Fixed
(1) **TRT cross attention not thread safe**. [Core changes like
this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6)
are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is
initialized only once, and make the cache read-only after
initialization. Previously, the content was not read-only, so it could be
changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the factory for
kernel loading has a static mutex to guard against concurrent access), so the
mutable variable could be set by two different threads at the same time.
Add a once_flag to avoid that.
This requires some workspace computation changes as well, so I did
not create a separate pull request.
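The "initialize once, then read-only" pattern described above can be illustrated with a rough Python analogue. This is only a sketch: the actual change is in the C++/CUDA code and uses a once_flag, while `OnceInitCache` and `build_cumulative_lengths` below are hypothetical names that do not exist in ONNX Runtime.

```python
import threading


class OnceInitCache:
    """Hypothetical cache that is filled once and only read afterwards."""

    def __init__(self, compute):
        self._compute = compute          # function that builds the cached data
        self._lock = threading.Lock()    # guards the one-time initialization
        self._initialized = False
        self._data = None

    def get(self):
        # Fast path: after initialization the data is never modified again.
        if not self._initialized:
            with self._lock:
                # Re-check inside the lock so only one thread runs compute().
                if not self._initialized:
                    self._data = tuple(self._compute())  # immutable snapshot
                    self._initialized = True
        return self._data


calls = []


def build_cumulative_lengths(batch_size=4, sequence_length=8):
    # Illustrative payload only: offsets 0, S, 2S, ... for a fixed length S.
    calls.append(1)
    return [i * sequence_length for i in range(batch_size + 1)]


cache = OnceInitCache(build_cumulative_lengths)
threads = [threading.Thread(target=cache.get) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls), cache.get())
```

Returning an immutable tuple mirrors the read-only requirement: no caller can modify the cached content after it has been built.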
(2) **Bias for cross attention**
That scenario assumes that only the query has bias, while key and value
do not. However, the assumption was not verified at runtime, it was not
documented in comments, and there was no test case, so support for the
scenario was disabled by mistake. The scenario is actually used in the
Whisper model (TODO: we shall add tests for Whisper to the CI pipeline, and
also update the fusion script to verify such assumptions if needed).
CUDA/CPU kernels support bias for cross attention as long as the bias is
zero for key and value. I updated the check to support the scenario and
added comments wherever this assumption is made.
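The assumption above (bias is zero for key and value) could be verified with a check along the following lines. This is a minimal sketch, not the check added in the change: `only_query_has_bias` is a hypothetical helper, and the flat `[Q | K | V]` bias layout is an assumption made for illustration.

```python
def only_query_has_bias(bias, hidden_size, v_hidden_size):
    """Return True if the key and value parts of a concatenated bias are all zero.

    Assumes a flat bias laid out as [Q | K | V] with
    hidden_size + hidden_size + v_hidden_size elements (illustrative layout).
    """
    expected = 2 * hidden_size + v_hidden_size
    if len(bias) != expected:
        raise ValueError(f"expected {expected} bias values, got {len(bias)}")
    key_value_part = bias[hidden_size:]  # everything after the query part
    return all(x == 0 for x in key_value_part)


q_only = [0.5, -0.25] + [0.0, 0.0] + [0.0, 0.0]   # bias on query only
full = [0.5, -0.25] + [0.1, 0.0] + [0.0, 0.0]     # non-zero key bias
print(only_query_has_bias(q_only, hidden_size=2, v_hidden_size=2))
print(only_query_has_bias(full, hidden_size=2, v_hidden_size=2))
```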
(3) **Fallback support**
Previously, the unfused kernel did not support packed qkv and packed kv
formats, so some cases could fail because there was no fallback. I
added new AddBiasTranspose CUDA kernels for these formats, so that every
supported case has a fallback.
#### Improvements
(4) **QKV workspace size**.
The logic for no_qkv_workspace could easily get out of sync because the
related code was scattered across different source files. I refactored the
code to move it all into one file (attention_prepare_qkv.cu) and added
asserts to keep the logic in sync.
(5) **Remove confusing concept of pass past in kv**
parameters.pass_past_in_kv is confusing since the k/v in cross attention
is not past state. Remove it and use parameters.qkv_format ==
Q_K_V_BSNH_BNSH_BNSH instead.
The new code does not use past_key/past_value for cross attention, so the
logic is clearer.
(6) **More coverage and less workspace and less transpose of flash and
efficient attention**
Previously, one condition prevented flash or efficient attention from
running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, flash and efficient attention can be used in this
case, with less workspace.
For example, for cross attention with bias, the original code used two
additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```
The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```
As shown, temp_k_workspace and temp_v_workspace no longer need to be
allocated, so less memory is used. The new code also saves two transposes
in this case.
Flash and efficient attention support BSNH or BNSH formats for k and v.
In the old code, k/v were converted to BSNH format even when that was not
necessary. I changed the code to convert k/v to BSNH or BNSH case by case,
so that more cases can be covered by flash or efficient attention to
improve performance.
(7) **Debugging support**
Previously, there was little debug info. In this change, I added a flag for
debug info in AttentionData, so that debug info can be output during
processing.
I also added functions to consolidate the dumping of inputs, QKV processing
and outputs, and added an environment variable `ORT_ENABLE_GPU_DUMP` that
allows disabling dumping from CUDA kernels.
#### Summary of changes
(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or
packed kv inputs.
(3) Change a few cases in PrepareQKV so that more cases are covered by
flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace
parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the
assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.
Currently supported scenarios for MultiHeadAttention on CUDA/CPU:
| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | ---------
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer)
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
    ort_session = create_ort_session(
        config, session_options, attention_kernel, use_symbolic_shape=False, use_tf32=use_tf32
    )
2024-06-19 17:23:26 +00:00
    cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
2024-06-12 20:04:25 +00:00
    shape_dict = config.shape_dict()
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, and removed the
backward and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works on Linux. For Windows, we will
enable it later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
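As a minimal sketch of how the two variables above might be set from a Python test script: the helper `flash_attention_env` is hypothetical, and the sketch assumes the variables are read when the session is created, so they are set before any session is constructed.

```python
import os


def flash_attention_env(disable_flash=False, min_seq_len_packed_qkv=None):
    """Build environment variable overrides for the two settings described above."""
    env = {}
    if disable_flash:
        # "1" disables flash attention; the default "0" leaves it enabled.
        env["ORT_DISABLE_FLASH_ATTENTION"] = "1"
    if min_seq_len_packed_qkv is not None:
        # Default is "513"; "0" requests flash attention whenever possible.
        env["ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"] = str(min_seq_len_packed_qkv)
    return env


# Assumption: apply the overrides before creating any onnxruntime session.
os.environ.update(flash_attention_env(min_seq_len_packed_qkv=0))
print(os.environ["ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"])
```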
### Speedup
The following results are from a Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for the MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
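The shapes implied by the three formats above can be written out as a small sketch. The helper `mha_input_shapes` is hypothetical, and the query shape BxSxNH for the `Q,KV` format is an assumption (only the 5D key shape is stated above).

```python
def mha_input_shapes(input_format, batch, seq_len, num_heads, head_size):
    """Return the tensor shapes implied by each input format described above."""
    B, S, N, H = batch, seq_len, num_heads, head_size
    if input_format == "Q,K,V":
        shape = (B, S, N * H)                      # separate 3D inputs
        return {"query": shape, "key": shape, "value": shape}
    if input_format == "Q,KV":
        # Assumption: query stays 3D; key carries packed K and V as 5D.
        return {"query": (B, S, N * H), "key": (B, S, N, 2, H)}
    if input_format == "QKV":
        return {"query": (B, S, N, 3, H)}          # packed Q, K and V as 5D
    raise ValueError(f"unknown input format: {input_format}")


for fmt in ("Q,K,V", "Q,KV", "QKV"):
    print(fmt, mha_input_shapes(fmt, batch=4, seq_len=2048, num_heads=32, head_size=128))
```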
Note that flash attention cannot use the packed QKV format directly, so an
extra Transpose is needed. We found that the TensorRT kernel is faster for
sequence length <= 512 for packed QKV. The reason might be that no transpose
is needed for the TensorRT kernel in this format.
We also noticed that the TensorRT kernel is faster for a stable diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for a 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a huge amount of memory while compiling the flash attention CUDA
kernels. A Linux build with CUDA might fail when the machine has limited
memory while the number of CPUs is large. A workaround is to use a build
machine with more memory, or to use an argument like `--nvcc_threads 1` to
limit nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
    cuda_session.allocate_buffers(shape_dict)
    return cuda_session
2024-06-12 20:04:25 +00:00
class OrtMultiHeadAttention:
    """A wrapper of ORT MultiHeadAttention to test relevance and performance."""
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)
2024-07-31 16:01:05 +00:00
    def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_tf32: bool = True):
        self.ort_session = create_session(config, session_options, use_tf32=use_tf32)
2024-06-12 20:04:25 +00:00
        self.feed_dict = config.random_inputs()
2024-08-26 20:34:55 +00:00
    def infer(self, run_options=None, synchronize=True):
        return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize)
2024-06-12 20:04:25 +00:00
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
def measure_latency(cuda_session: CudaSession, input_dict):
    start = time.time()
    _ = cuda_session.infer(input_dict)
    end = time.time()
    return end - start
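measure_latency above times a single inference call. If the caller does not already average over several runs, a variant with warm-up and repeats could look like the following sketch. The names `average_latency`, `run_once`, `warmup` and `repeats` are hypothetical and not part of the benchmark script.

```python
import time


def average_latency(run_once, warmup=3, repeats=10):
    """Average wall-clock latency of run_once() after a few warm-up calls."""
    for _ in range(warmup):
        run_once()                       # warm-up runs are not timed
    start = time.perf_counter()          # monotonic, high-resolution clock
    for _ in range(repeats):
        run_once()
    return (time.perf_counter() - start) / repeats


# Stand-in workload; in the benchmark this could wrap a session's infer call.
calls = []
latency = average_latency(lambda: calls.append(1), warmup=2, repeats=5)
print(len(calls), latency >= 0.0)
```

time.perf_counter is used instead of time.time because it is monotonic and intended for measuring short durations.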
def flops(batch, sequence_length, head_size, num_heads, causal):
    return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1)
def tflops_per_second(flop, time):
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses BNSH layout while ort uses BSNH layout, so
the comparison is not apples-to-apples. However, a large latency
difference could be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) on Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
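
The tflops column can be related to the latency column with a short calculation. This is a sketch under one assumption: the FLOP count for non-causal attention is taken as 4 * B * S^2 * N * H (two matrix multiplications of 2 * S * S * H each per head), a common convention that is not stated in this description. `attention_flops` is a hypothetical helper; `tflops_per_second` follows the function of the same name in the script.

```python
import math


def attention_flops(batch_size, sequence_length, num_heads, head_size):
    """Assumed FLOP count for non-causal attention: Q*K^T plus P*V."""
    return 4 * batch_size * sequence_length**2 * num_heads * head_size


def tflops_per_second(flop, time):
    """Convert a FLOP count and a latency in seconds to TFLOPs per second."""
    try:
        return (flop / time / 10**12) if not math.isnan(time) else 0.0
    except ZeroDivisionError:
        return None


# First row of the GPU table: batch 4, sequence 2048, 32 heads, head size 128.
flop = attention_flops(4, 2048, 32, 128)
print(flop)
print(round(tflops_per_second(flop, 0.0015), 1))
```

With the rounded latency of 0.0015 s this gives roughly 183 TFLOPs/s, close to the reported 179.5; the gap is consistent with the latency column being rounded to four decimals.
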
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) on Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
    try:
        return (flop / time / 10**12) if not math.isnan(time) else 0.0
    except ZeroDivisionError:
        return None


def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str:
    kernel_names = {
        SdpaKernel.DEFAULT: "ort:default",
        SdpaKernel.FLASH_ATTENTION: "ort:flash",
        SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient",
        SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn",
        SdpaKernel.MATH: "ort:math",
    }
    assert attention_kernel in kernel_names
    return kernel_names[attention_kernel]
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, then removed the
backward and bfloat16-related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works on Linux. For Windows, we will
enable it later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enabled). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
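
The rule described by the two variables can be sketched as follows. This only mirrors the wording above and is not ONNX Runtime's actual selection logic (which also depends on the other limitations listed earlier); the helper name `flash_enabled_for_packed_qkv` is hypothetical.

```python
import os


def flash_enabled_for_packed_qkv(sequence_length, env=None):
    """Return True if the described rule would allow flash attention.

    Assumption: only the two environment variables are considered.
    """
    env = os.environ if env is None else env
    # (1) A value of "1" disables flash attention entirely.
    if env.get("ORT_DISABLE_FLASH_ATTENTION", "0") == "1":
        return False
    # (2) Packed QKV uses flash attention only at or above this length.
    min_len = int(env.get("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", "513"))
    return sequence_length >= min_len


# Defaults: flash attention is used only for sequence length > 512.
print(flash_enabled_for_packed_qkv(512, {}))
print(flash_enabled_for_packed_qkv(513, {}))
# Threshold lowered to 0: use flash attention whenever possible.
print(flash_enabled_for_packed_qkv(128, {"ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV": "0"}))
```

In a real run the variables would presumably need to be set in the process environment before the inference session is created.
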
### Speedup
The following results are from a Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for the MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
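
The three formats can be summarized by the tensor shapes each one passes to the operator. This is a minimal sketch with a hypothetical helper `mha_input_shapes`; the query shape for the `Q,KV` format (BxSxNH) is an assumption, since the list above only gives the shape of the packed key.

```python
def mha_input_shapes(fmt, B, S, N, H):
    """Return the input shapes for one of the three formats.

    B: batch size, S: sequence length, N: number of heads, H: head size.
    """
    if fmt == "Q,K,V":
        # Three separate 3D inputs with hidden size N*H.
        shape = (B, S, N * H)
        return {"query": shape, "key": shape, "value": shape}
    if fmt == "Q,KV":
        # Key and value packed into one 5D tensor (query shape is assumed).
        return {"query": (B, S, N * H), "key": (B, S, N, 2, H)}
    if fmt == "QKV":
        # Query, key and value packed into one 5D tensor.
        return {"query": (B, S, N, 3, H)}
    raise ValueError(f"unknown format: {fmt}")


def element_count(shapes):
    """Total number of elements across all inputs."""
    total = 0
    for shape in shapes.values():
        count = 1
        for dim in shape:
            count *= dim
        total += count
    return total


for fmt in ("Q,K,V", "Q,KV", "QKV"):
    shapes = mha_input_shapes(fmt, B=4, S=2048, N=32, H=128)
    print(fmt, shapes, element_count(shapes))
```

All three formats carry the same number of elements; only the packing differs.
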
Note that flash attention cannot use the packed QKV format directly, so an
extra Transpose is needed. We found that the TensorRT kernel is faster for
sequence length <= 512 with packed QKV. The reason might be that no
transpose is needed for the TensorRT kernel in this format.
We also noticed that the TensorRT kernel is faster for a stable diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for a 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention
CUDA kernel. A Linux build with CUDA might fail when the machine has
limited memory but a large number of CPUs. The workaround is to use a
build machine with more memory, or to pass an argument like
`--nvcc_threads 1` to limit the number of nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str:
    # CPU Flash Attention does not support causal and kv cache etc.
    if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0):
        if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1":
            return "ort:flash"
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses BSNH, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
|
|
|
return "ort:math"
|
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, then removed backward
and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works in Linux. For Windows, we will
enable later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
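As a minimal sketch, the two variables above could be set from a Python test script before an ONNX Runtime session is created. The helper name `flash_attention_env` is hypothetical; only the variable names and values come from the description above, and the assumption that the variables must be set before session creation is labeled in the code.

```python
import os


def flash_attention_env(disable_flash: bool = False, min_seq_len_packed_qkv: int = 513) -> dict:
    """Build the environment settings described above (hypothetical helper)."""
    return {
        # "1" disables flash attention; "0" (the default) leaves it enabled.
        "ORT_DISABLE_FLASH_ATTENTION": "1" if disable_flash else "0",
        # Minimum sequence length for flash attention with packed QKV input.
        # The default "513" means only sequences longer than 512 use it.
        "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV": str(min_seq_len_packed_qkv),
    }


# Use flash attention v2 whenever possible, including short packed QKV sequences.
# Assumption: the variables are read when the session is created, so set them first.
os.environ.update(flash_attention_env(disable_flash=False, min_seq_len_packed_qkv=0))
```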
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
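The three formats listed above can be illustrated by the tensor shapes each one implies. This is a sketch only: `mha_input_shapes` is a hypothetical helper, and it assumes that the query stays 3D (BxSxNH) in the `Q,KV` format and that key and value share the query's sequence length in the `Q,K,V` format.

```python
def mha_input_shapes(fmt: str, B: int, S: int, N: int, H: int) -> dict:
    """Return input tensor shapes for one MultiHeadAttention input format.

    B = batch size, S = sequence length, N = number of heads, H = head size.
    """
    if fmt == "Q,K,V":
        # Three separate 3D tensors with the heads folded into the hidden dimension.
        shape = (B, S, N * H)
        return {"query": shape, "key": shape, "value": shape}
    if fmt == "Q,KV":
        # The key input carries both K and V in one 5D tensor.
        return {"query": (B, S, N * H), "key": (B, S, N, 2, H)}
    if fmt == "QKV":
        # The query input carries Q, K and V in one 5D tensor.
        return {"query": (B, S, N, 3, H)}
    raise ValueError(f"unknown input format: {fmt}")


print(mha_input_shapes("QKV", B=4, S=2048, N=32, H=128))
```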
Note that flash attention cannot use packed QKV format, so an extra
Transpose is needed. We found that the TensorRT kernel is faster for
sequence length <= 512 with packed QKV. The reason might be that no transpose
is needed for the TensorRT kernel in this format.
We also noticed that the TensorRT kernel is faster for Stable Diffusion
512x512 images (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 images (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention CUDA
kernels. A Linux build with CUDA might fail when the machine has limited memory
and a large number of CPUs. The workaround is to use a build machine with more
memory, or to pass an argument like `--nvcc_threads 1` to limit nvcc threads in
the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
|
|
|
|
|
|
|
|
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency difference
could still be a warning sign.
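To make the layout difference concrete, the sketch below reorders a flat BSNH buffer into BNSH order using plain index arithmetic. The function names are hypothetical and the code only illustrates the two memory layouts (B = batch, S = sequence, N = heads, H = head size); it is not taken from the benchmark script.

```python
def bsnh_offset(b: int, s: int, n: int, h: int, S: int, N: int, H: int) -> int:
    """Flat offset of element (b, s, n, h) in a BSNH (batch, sequence, head, size) buffer."""
    return ((b * S + s) * N + n) * H + h


def bnsh_offset(b: int, s: int, n: int, h: int, S: int, N: int, H: int) -> int:
    """Flat offset of the same element in a BNSH (batch, head, sequence, size) buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat: list, B: int, S: int, N: int, H: int) -> list:
    """Copy a flat BSNH buffer into BNSH order (swap the S and N axes)."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = bsnh_offset(b, s, n, h, S, N, H)
                    dst = bnsh_offset(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


# Two tokens, two heads, head size 1: BSNH order is [s0n0, s0n1, s1n0, s1n1].
print(bsnh_to_bnsh([0, 1, 2, 3], B=1, S=2, N=2, H=1))  # [0, 2, 1, 3]
```

Because the per-head rows are contiguous in BNSH but strided in BSNH, kernels may access memory differently in the two cases, which is one reason the latencies are not directly comparable.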
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
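The tflops column can be related to the latency column with a short calculation. The sketch below assumes the common FLOP count for non-causal attention: two matrix multiplications (QK^T and the attention weights times V), each costing 2 x B x N x S x S x H floating point operations. This convention and the function name `attention_tflops` are assumptions, not taken from benchmark_mha.py.

```python
def attention_tflops(batch_size: int, sequence_length: int, num_heads: int,
                     head_size: int, latency_s: float) -> float:
    """Estimate TFLOPs per second for non-causal attention (assumed FLOP count)."""
    # Two matmuls (QK^T and PV), each 2 * B * N * S * S * H FLOPs.
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_s / 1e12


# First table row: B=4, S=2048, N=32, H=128, latency 0.0015 s.
print(round(attention_tflops(4, 2048, 32, 128, 0.0015), 1))  # about 183.3
```

The estimate for the first row is close to, but not equal to, the reported 179.5. The difference is plausibly explained by the latency being rounded to four decimal places in the table.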
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
# Functions for benchmarking PyTorch SDPA
|
|
|
|
|
# ------------------------------------------------------------------
|
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
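The four attention bias shapes named in the description amount to allowing broadcast on the first two dimensions only. A minimal sketch of that check follows; `is_supported_attention_bias_shape` is a hypothetical helper written for illustration and is not the validation code used by the operators.

```python
def is_supported_attention_bias_shape(shape: tuple, B: int, N: int, S: int, T: int) -> bool:
    """Check a bias shape against [1|B, 1|N, S, T] as listed in the description.

    B = batch size, N = number of heads, S = query sequence length,
    T = total key sequence length.
    """
    if len(shape) != 4:
        return False
    b, n, s, t = shape
    # Only the batch and head dimensions may be broadcast (size 1).
    return b in (1, B) and n in (1, N) and s == S and t == T


B, N, S, T = 2, 8, 16, 16
for shape in [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T), (B, N, 1, T)]:
    print(shape, is_supported_attention_bias_shape(shape, B, N, S, T))
```

Only the last shape in the loop is rejected, because it tries to broadcast the third dimension.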
2024-08-16 22:40:04 +00:00
|
|
|
def benchmark_torch_function(repeats: int, func: Callable, *args, **kwargs) -> float:
|
|
|
|
warmup = 5
|
|
|
|
|
for _ in range(warmup):
|
|
|
|
|
func(*args, **kwargs)
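The warm-up pattern in the fragment above can be shown end to end with a CPU-only sketch. The name `benchmark_function` and the use of `time.perf_counter` are assumptions made for illustration; the original function body is not fully visible here, and a GPU version would additionally need to synchronize the device (for example with `torch.cuda.synchronize()`) before reading the clock.

```python
import time
from typing import Callable


def benchmark_function(repeats: int, func: Callable, *args, warmup: int = 5, **kwargs) -> float:
    """Return the average latency of func in seconds (CPU-only sketch)."""
    # Warm-up runs are excluded from timing so one-time costs do not skew the result.
    for _ in range(warmup):
        func(*args, **kwargs)

    start = time.perf_counter()
    for _ in range(repeats):
        func(*args, **kwargs)
    return (time.perf_counter() - start) / repeats


# Usage: time a small pure-Python workload.
latency = benchmark_function(10, sum, range(10000))
print(f"average latency: {latency:.6f} s")
```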
|
2024-06-12 20:04:25 +00:00
|
|
|
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency difference
could still be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )

    return timer.timeit(number=repeats).median
2024-06-12 20:04:25 +00:00
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses BSNH, so
the results are not an apples-to-apples comparison. However, a large
latency difference can still serve as a warning sign.
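To make the layout difference concrete, here is a minimal sketch of how the same logical element maps to different flat offsets in contiguous BSNH and BNSH buffers. The helper names `offset_bsnh` and `offset_bnsh` are hypothetical and not part of benchmark_mha.py; the sketch only assumes row-major, contiguous storage.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, n, s, h, S, N, H):
    """Flat offset of element (b, n, s, h) in a contiguous B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


# Distance in elements between two consecutive sequence positions of one head.
S, N, H = 2048, 32, 128
stride_bsnh = offset_bsnh(0, 1, 0, 0, S, N, H) - offset_bsnh(0, 0, 0, 0, S, N, H)
stride_bnsh = offset_bnsh(0, 0, 1, 0, S, N, H) - offset_bnsh(0, 0, 0, 0, S, N, H)
print(stride_bsnh, stride_bnsh)
```

In BSNH, stepping along the sequence for a fixed head skips over all heads (a stride of N*H elements), while in BNSH the stride is only H. This is one reason the two frameworks are not doing identical memory work for the same problem size.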
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
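The tflops column can be approximately related to the latency column. The sketch below assumes the commonly used count of 4 * B * S^2 * N * H floating point operations for non-causal attention (two matrix multiplications, each 2 * B * N * S^2 * H). That formula is an assumption here, since the text above does not state how the benchmark computes the metric, and the function names are hypothetical.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size):
    """Assumed flop count for non-causal attention: QK^T plus (softmax)V,
    each costing 2 * B * N * S * S * H multiply-add operations."""
    return 4 * batch_size * sequence_length * sequence_length * num_heads * head_size


def tflops(batch_size, sequence_length, num_heads, head_size, latency_seconds):
    """Throughput in TFLOPS for one attention call taking latency_seconds."""
    return attention_flops(batch_size, sequence_length, num_heads, head_size) / latency_seconds / 1e12


# First row of the GPU table: B=4, S=2048, N=32, H=128, latency 0.0015 s (rounded).
estimate = tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f}")
```

With the rounded latency of 0.0015 s this gives roughly 183 TFLOPS, in the same range as the reported 179.5; the gap is consistent with the latency being rounded to four decimal places in the table.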
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
def run_torch_sdpa(
    batch_size: int,
    q_seq_len: int,
    kv_seq_len: int,
    num_heads: int,
    head_size: int,
    causal: bool,
    device,
    dtype,
    has_mask: bool = False,
    mask_dim: int = 2,
    mask_dtype=torch.bool,
    backend: Optional[int] = None,
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
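The broadcast rule described above can be sketched in plain Python. This is an illustration only, not the operator's implementation: the function names are hypothetical, and it assumes B is the batch size, N the number of heads, S the query sequence length and T the total key sequence length.

```python
def is_supported_bias_shape(bias_shape, batch_size, num_heads, seq_len, total_seq_len):
    """True if bias_shape is [1 or B, 1 or N, S, T], the four listed variants."""
    if len(bias_shape) != 4:
        return False
    b, n, s, t = bias_shape
    return (
        b in (1, batch_size)
        and n in (1, num_heads)
        and s == seq_len
        and t == total_seq_len
    )


def bias_offset(bias_shape, b, n, s, t):
    """Flat offset into a contiguous bias buffer for the full index (b, n, s, t).
    A dimension of size 1 is broadcast, so its index collapses to 0."""
    dim_b, dim_n, dim_s, dim_t = bias_shape
    bb = 0 if dim_b == 1 else b
    nn = 0 if dim_n == 1 else n
    return ((bb * dim_n + nn) * dim_s + s) * dim_t + t


# A [B, 1, S, T] bias is shared by every head of a given batch entry.
shape = (2, 1, 4, 6)
print(is_supported_bias_shape(shape, 2, 8, 4, 6))
print(bias_offset(shape, 1, 5, 2, 3) == bias_offset(shape, 1, 0, 2, 3))
```

Collapsing the index of a size-1 dimension to 0 is the standard way to read a broadcast tensor without materializing the full [B, N, S, T] buffer.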
2024-08-16 22:40:04 +00:00
    repeats: int = 100,
):
    q_shape = (batch_size, num_heads, q_seq_len, head_size)
    kv_shape = (batch_size, num_heads, kv_seq_len, head_size)
    q = torch.randn(q_shape, device=device, dtype=dtype)
    k = torch.randn(kv_shape, device=device, dtype=dtype)
    v = torch.randn(kv_shape, device=device, dtype=dtype)

    attn_mask = None
    if has_mask:
        mask_shape = (batch_size, num_heads, q_seq_len, kv_seq_len) if mask_dim == 4 else (q_seq_len, kv_seq_len)
        attn_mask = torch.ones(mask_shape, dtype=mask_dtype, device=device)

    context = sdpa_kernel(backend) if backend is not None else nullcontext()

    with context:
        average_latency = benchmark_torch_function(
            repeats,
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=causal,
            attn_mask=attn_mask,
        )
    return average_latency
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supports relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supports [1, N, S, T]. This will extend the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
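The four attention bias shapes listed above can be illustrated with NumPy broadcasting. This is only a sketch of the shape semantics: the sizes and the zero-filled `scores` tensor are made up for illustration, and it does not reflect how the CUDA or CPU kernels are implemented.

```python
import numpy as np

# Illustrative sizes (assumptions): batch, heads, query length, total key length.
B, N, S, T = 2, 4, 3, 5

# Stand-in for the scaled Q @ K^T attention scores.
scores = np.zeros((B, N, S, T), dtype=np.float32)

supported_bias_shapes = [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T)]
for bias_shape in supported_bias_shapes:
    bias = np.ones(bias_shape, dtype=np.float32)
    # Size-1 batch and head dimensions are broadcast across the full tensor.
    biased = scores + bias
    assert biased.shape == (B, N, S, T)
```

Every element of `scores` receives the bias exactly once, regardless of which of the four shapes is used.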
2024-08-16 22:40:04 +00:00
def get_test_configs(args: argparse.Namespace):
    use_gpu: bool = args.use_gpu

    if args.batch_size > 0:
        run_unfused = args.sequence_length + args.past_sequence_length <= (2048 if use_gpu else 1024)
        return [
            (
                args.batch_size,
                args.sequence_length,
                args.past_sequence_length,
                args.num_heads,
                args.head_size,
                run_unfused,
            ),
        ]
2024-06-12 20:04:25 +00:00
    if use_gpu:
        # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused)
        configs = [
            (32, 512, 0, 64, 32, True),
            (32, 512, 0, 128, 16, True),
            (16, 1024, 0, 64, 32, True),
            (16, 1024, 0, 128, 16, True),
            (8, 2048, 0, 64, 32, True),
            (8, 2048, 0, 128, 16, False),
            (4, 4096, 0, 64, 32, False),
            (4, 4096, 0, 128, 16, False),
            (2, 8192, 0, 64, 32, False),
            (2, 8192, 0, 128, 16, False),
            (1, 16384, 0, 64, 32, False),
            (1, 16384, 0, 128, 16, False),
            # stable diffusion
            (1, 4096, 0, 8, 40, False),
            (1, 4096, 0, 8, 80, False),
            (1, 4096, 0, 8, 160, False),
            (4, 4096, 0, 8, 40, False),
            (4, 4096, 0, 8, 80, False),
            (4, 4096, 0, 8, 160, False),
            (1, 16384, 0, 8, 40, False),
            (1, 16384, 0, 8, 80, False),
            (1, 16384, 0, 8, 160, False),
            # bert-base
            (128, 128, 0, 12, 64, True),
            (64, 128, 0, 12, 64, True),
            (128, 384, 0, 12, 64, True),
            (64, 384, 0, 12, 64, True),
            (128, 512, 0, 12, 64, True),
            (64, 512, 0, 12, 64, True),
            # TNLGv4
            (4, 2048, 0, 32, 128, True),
            (4, 4096, 0, 32, 128, False),
            (8, 2048, 0, 32, 128, False),
            (8, 4096, 0, 32, 128, False),
        ]
    else:
        configs = [
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses BNSH layout while ort uses BSNH layout, so
the comparison is not apples-to-apples. However, a large latency difference
could still be a warning sign.
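The layout difference can be made concrete with a small NumPy sketch. The tensor sizes are arbitrary, and the variable names are hypothetical rather than taken from benchmark_mha.py; the point is only that moving between BSNH and BNSH requires a transpose of the sequence and head axes.

```python
import numpy as np

# Illustrative sizes (assumptions): batch, sequence length, heads, head size.
B, S, N, H = 2, 8, 4, 16

# ORT-style query: B x S x (N*H), which is BSNH once the hidden dimension is split.
q_bsd = np.arange(B * S * N * H, dtype=np.float32).reshape(B, S, N * H)
q_bsnh = q_bsd.reshape(B, S, N, H)

# torch-style layout: B x N x S x H. The transpose swaps the S and N axes,
# and making it contiguous copies the data into the new memory order.
q_bnsh = np.ascontiguousarray(q_bsnh.transpose(0, 2, 1, 3))

assert q_bnsh.shape == (B, N, S, H)
# The same logical element lives at swapped (sequence, head) indices.
assert q_bnsh[1, 3, 5, 7] == q_bsnh[1, 5, 3, 7]
```

Because the two layouts walk memory in different orders, the same kernel can show different latency on each, which is why the comparison is only indicative.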
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
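As a rough cross-check of the tflops column, the sketch below recomputes throughput from latency. It assumes the common convention of counting two matrix multiplications (Q·Kᵀ and P·V) at 2·B·N·S·S·H floating-point operations each for non-causal attention with no past state; whether benchmark_mha.py uses exactly this formula is an assumption, and the function name is hypothetical.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPS for non-causal attention with past_sequence_length=0.

    Assumed FLOP count: two matmuls (Q @ K^T and P @ V), each costing
    2 * B * N * S * S * H floating-point operations.
    """
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_s / 1e12


# First row of the GPU table: batch 4, sequence 2048, 32 heads, head size 128.
tflops = attention_tflops(4, 2048, 32, 128, latency_s=0.0015)
print(f"{tflops:.1f} TFLOPS")
```

With the rounded latency of 0.0015 s this gives roughly 183 TFLOPS, close to the 179.5 reported in the table; the gap is consistent with the latency column being rounded to four decimal places.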
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
            # TNLGv4
2024-06-12 20:04:25 +00:00
            (1, 128, 0, 32, 128, True),
            (1, 256, 0, 32, 128, True),
            (1, 512, 0, 32, 128, True),
            (1, 1024, 0, 32, 128, True),
            # (1, 2048, 0, 32, 128, True),
            # bert-base
            (1, 128, 0, 12, 64, True),
            (1, 384, 0, 12, 64, True),
            (1, 512, 0, 12, 64, True),
            (4, 128, 0, 12, 64, True),
            (4, 384, 0, 12, 64, True),
            (4, 512, 0, 12, 64, True),
            # bert-large
            (1, 128, 0, 16, 64, True),
            (1, 384, 0, 16, 64, True),
            (1, 512, 0, 16, 64, True),
            (4, 128, 0, 16, 64, True),
            (4, 384, 0, 16, 64, True),
            (4, 512, 0, 16, 64, True),
2024-06-12 20:04:25 +00:00
        ]
    return configs


def get_compute_capability():
    assert torch.cuda.is_available()
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor
    return sm
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, and removed the
backward and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works in Linux. For Windows, we will
enable later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
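A minimal sketch of setting these two variables from Python follows. The helper name `configure_flash_attention` is hypothetical, and the sketch assumes the variables must be set in the process environment before the ONNX Runtime session is created so that the attention kernels can read them.

```python
import os


def configure_flash_attention(disable=False, min_seq_len_packed_qkv=513):
    """Set the two testing environment variables described above.

    Hypothetical helper: call it before creating the inference session
    (assumption: the variables are read when the kernels are set up).
    """
    os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" if disable else "0"
    os.environ["ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"] = str(min_seq_len_packed_qkv)


# Use flash attention v2 whenever possible, including short packed-QKV sequences.
configure_flash_attention(disable=False, min_seq_len_packed_qkv=0)
```

Calling the helper with its defaults restores the documented default values of "0" and "513".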
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
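The three input formats can be sketched with NumPy as follows. The sizes and random data are illustrative only, and the variable names are hypothetical; the sketch shows just the tensor shapes and how the packed layouts interleave the separate inputs.

```python
import numpy as np

# Illustrative sizes (assumptions): batch, sequence length, heads, head size.
B, S, N, H = 2, 8, 4, 16
rng = np.random.default_rng(0)

# `Q,K,V`: three separate inputs, each B x S x (N*H).
query = rng.standard_normal((B, S, N * H)).astype(np.float16)
key = rng.standard_normal((B, S, N * H)).astype(np.float16)
value = rng.standard_normal((B, S, N * H)).astype(np.float16)

# Split the hidden dimension into heads: B x S x N x H.
q4, k4, v4 = (x.reshape(B, S, N, H) for x in (query, key, value))

# `Q,KV`: key and value packed along a new axis, B x S x N x 2 x H.
packed_kv = np.stack([k4, v4], axis=3)

# `QKV`: query, key and value packed together, B x S x N x 3 x H.
packed_qkv = np.stack([q4, k4, v4], axis=3)

assert packed_kv.shape == (B, S, N, 2, H)
assert packed_qkv.shape == (B, S, N, 3, H)
# Index 1 along the packing axis of QKV recovers the key.
assert np.array_equal(packed_qkv[:, :, :, 1, :], k4)
```

In the packed layouts, the per-head slices of Q, K and V sit next to each other in memory, so a kernel that expects separate contiguous tensors has to rearrange the data first.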
Note that flash attention cannot use the packed QKV format directly, so an
extra Transpose is needed. We found that the TensorRT kernel is faster for
sequence length <= 512 with packed QKV. The reason might be that the
TensorRT kernel needs no transpose in this format.
We also noticed that the TensorRT kernel is faster for stable diffusion
512x512 images (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 images (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention
CUDA kernels. A Linux build with CUDA might fail when the machine has
limited memory but a large number of CPUs. The workaround is to use a
build machine with more memory, or to pass an argument like
`--nvcc_threads 1` to limit nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
2024-08-22 00:30:16 +00:00
class CaptureStdout:
    def __init__(self):
        self.fd = sys.stdout.fileno()
        self.chunk_size = 1024
        self.output = b""

    def _capture(self):
        chunks = []
        while chunk := os.read(self._pipe_reader, self.chunk_size):
            chunks.append(chunk)
        self.output = b"".join(chunks)

    def __enter__(self):
        self._duped_fd = os.dup(self.fd)
        self._pipe_reader, pipe_writer = os.pipe()
        os.dup2(pipe_writer, self.fd)
        os.close(pipe_writer)
        self._capture_thread = threading.Thread(target=self._capture)
        self._capture_thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.close(self.fd)
        self._capture_thread.join()
        os.close(self._pipe_reader)
        os.dup2(self._duped_fd, self.fd)
        os.close(self._duped_fd)
def sdpa_kernel_from_debug_info(
    config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel, sess_options: SessionOptions
):
    os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1"
    captured_text = None
    try:
        with CaptureStdout() as captured:
            session = create_session(config, sess_options, attention_kernel=attention_kernel)
            input_dict = config.random_inputs()
            session.infer(input_dict)
        captured_text = captured.output.decode()
    except Exception as e:
        print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}")
    finally:
        os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0"

    if captured_text is not None:
        m = re.search("SdpaKernel=(?P<kernel>[A-Z_]+)", captured_text)
        if m is not None:
            name = m.group("kernel")
            kernel_names = {
                "FLASH_ATTENTION": "ort:flash",
                "EFFICIENT_ATTENTION": "ort:efficient",
                "CUDNN_FLASH_ATTENTION": "ort:cudnn",
                "MATH": "ort:math",
                "TRT_FUSED_ATTENTION": "ort:trt_fmha",
                "TRT_FLASH_ATTENTION": "ort:trt_flash",
                "TRT_CROSS_ATTENTION": "ort:trt_cross",
                "TRT_CAUSAL_ATTENTION": "ort:trt_causal",
            }
            return kernel_names[name]
        else:
            print("Failed to get sdpa kernel from debug info:", captured_text)

    return None
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses BNSH layout while ort uses BSNH layout, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
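To make the layout difference concrete, here is a minimal pure-Python sketch of how the same logical element (b, s, n, h) maps to different flat offsets in contiguous BSNH and BNSH buffers. The helper names (`offset_bsnh`, `offset_bnsh`, `bsnh_to_bnsh`) are illustrative assumptions, not functions from benchmark_mha.py.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous BSNH buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Flat offset of the same element in a contiguous BNSH buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Copy a flat BSNH buffer into BNSH order (swap the S and N axes)."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


# Tiny example: B=1, S=2, N=2, H=2, values equal to their BSNH offsets.
B, S, N, H = 1, 2, 2, 2
bsnh = list(range(B * S * N * H))
bnsh = bsnh_to_bnsh(bsnh, B, S, N, H)
print(bnsh)
```

The element order differs between the two layouts, so each framework reads memory with a different access pattern, which is one reason the latencies are not directly comparable.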
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
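The tflops column can be sanity-checked against the latency column. The sketch below assumes the common estimate of 4 * B * S * S_kv * N * H floating-point operations for non-causal attention (two matrix multiplications, two FLOPs per multiply-add), halved for causal attention. The function names are illustrative and are not confirmed to match benchmark_mha.py.

```python
def attention_flops(batch_size, sequence_length, kv_sequence_length,
                    num_heads, head_size, causal=False):
    """Estimated FLOPs of scaled dot-product attention (assumed formula)."""
    flops = 4 * batch_size * sequence_length * kv_sequence_length * num_heads * head_size
    # Assumption: a causal mask skips roughly half of the work.
    return flops // 2 if causal else flops


def tflops_per_second(flops, latency_seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPS."""
    return flops / latency_seconds / 1e12


# First row of the table: batch 4, sequence length 2048, 32 heads, head size 128.
flops = attention_flops(4, 2048, 2048, 32, 128)
estimate = tflops_per_second(flops, 0.0015)
print(f"{estimate:.1f} TFLOPS")
```

With the latency rounded to 0.0015 s this gives roughly 183 TFLOPS, close to the reported 179.5; the gap is consistent with rounding in the latency column.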
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
def run_tflops_test(
    csv_writer: csv.DictWriter,
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
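As an illustration of the four accepted shapes, the following pure-Python sketch checks whether a bias shape can broadcast to [B, N, S, T] and computes which stored element a logical position reads. Both helpers (`is_supported_bias_shape`, `bias_offset`) are hypothetical names written for this example; they are not the operator's actual validation code.

```python
def is_supported_bias_shape(shape, B, N, S, T):
    """True if shape is one of [1|B, 1|N, S, T] (the first two dims may broadcast)."""
    if len(shape) != 4:
        return False
    b, n, s, t = shape
    return b in (1, B) and n in (1, N) and s == S and t == T


def bias_offset(shape, b, n, s, t):
    """Flat offset into a contiguous bias of the given shape for logical (b, n, s, t).

    A dimension of size 1 is broadcast, so its index is pinned to 0.
    """
    bias_b, bias_n, S, T = shape
    bb = b if bias_b > 1 else 0
    nn = n if bias_n > 1 else 0
    return ((bb * bias_n + nn) * S + s) * T + t


B, N, S, T = 2, 8, 4, 4
for shape in [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T)]:
    print(shape, is_supported_bias_shape(shape, B, N, S, T))

# Head 5 of batch 1 reads the same stored element as head 0 when dim 1 is broadcast.
print(bias_offset((B, 1, S, T), 1, 5, 2, 3), bias_offset((B, 1, S, T), 1, 0, 2, 3))
```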
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
    args: argparse.Namespace,
):
    use_gpu: bool = args.use_gpu
    enable_cuda_graph: bool = args.use_cuda_graph
    causal: bool = args.causal
    intra_op_num_threads: int = args.intra_op_num_threads
    repeats: int = args.repeats

    print(f"run_tflops_test: causal={causal}")

    if use_gpu:
        device_id = torch.cuda.current_device()
        device = torch.device("cuda", device_id)
        formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]
        provider = "CUDAExecutionProvider"
        # flash attention is available for sm >= 80
        sm = get_compute_capability()
        if sm >= 80:
[CUDA] cuDNN Flash Attention (#21629)
### Description
- [x] Add cuDNN flash attention using cudnn frontend, and enable it in
MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.
The cuDNN SDPA is disabled by default. To enable it, the following are required:
(1) cuDNN 9.3 or newer must be installed.
(2) Set the environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1` or
set the `sdpa_kernel=8` CUDA provider option.
(3) The device must have compute capability >= 8.0.
Note that some combinations of parameters might be rejected due to
limited support of head dimension or sequence lengths.
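The two ways of enabling the kernel could look like the sketch below. It only builds the settings and does not import onnxruntime; the commented-out session creation and the model path `model.onnx` are assumptions for illustration, and the option value is passed as a string on the assumption that provider options are string-valued. Either option alone should be enough according to the description above.

```python
import os

# Option 1: the environment variable named in the description.
# It must be set before the session is created.
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option 2: the sdpa_kernel CUDA provider option; 8 selects cuDNN flash
# attention according to the description above.
CUDNN_FLASH_ATTENTION = 8
providers = [
    ("CUDAExecutionProvider", {"sdpa_kernel": str(CUDNN_FLASH_ATTENTION)}),
    "CPUExecutionProvider",
]

# Hypothetical usage (requires onnxruntime-gpu and a real model file):
# import onnxruntime
# session = onnxruntime.InferenceSession("model.onnx", providers=providers)
print(providers[0])
```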
Future work:
(1) FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add an API to support ragged batching (padding removed from inputs).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, and k/v are converted to either BSNH
or BNSH format. Experiments may show whether converting q to BNSH is
better in some cases.
### Example Benchmark Results on H100
The following tests are on FP16 MultiHeadAttention operator without
attention mask and attention bias.
#### Test Setting 1
batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128
format | average_latency | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient
#### Test Setting 2
batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64
format | average_latency | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH) | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH) | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH) | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient
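The tflops column in the two tables above can be approximately reproduced from the latency column. The sketch below assumes that `average_latency` is in seconds and that the FLOP count is `4 * B * N * S * T * H` for non-causal attention (two matrix multiplications, two FLOPs per multiply-add); the function name is hypothetical, and the formula is an assumption that happens to agree with the reported rows.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size,
                     latency_seconds, kv_sequence_length=None):
    """Estimate TFLOPS for non-causal attention.

    Assumed FLOP count: Q*K^T and P*V each cost 2*B*N*S*T*H FLOPs.
    """
    total = sequence_length if kv_sequence_length is None else kv_sequence_length
    flops = 4 * batch_size * num_heads * sequence_length * total * head_size
    return flops / latency_seconds / 1e12


# First row of each table (torch:flash), using the rounded latencies shown.
setting_1 = attention_tflops(16, 256, 32, 128, 0.000075)
setting_2 = attention_tflops(16, 512, 16, 64, 0.000069)
print(f"setting 1: {setting_1:.1f} TFLOPS, setting 2: {setting_2:.1f} TFLOPS")
```

The estimates land within about one TFLOPS of the reported 229.5 and 249.2; the small gap is expected because the latencies in the tables are rounded to six decimal places.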
2024-08-20 15:50:22 +00:00
            backends = [
                SdpaKernel.DEFAULT,
                SdpaKernel.FLASH_ATTENTION,
                SdpaKernel.EFFICIENT_ATTENTION,
                SdpaKernel.CUDNN_FLASH_ATTENTION,
                SdpaKernel.MATH,
            ]
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
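One way to read the CPU table above is to compute latency ratios between kernels at the same sequence length and thread count. The sketch below copies four rows from the table; the function name and the choice of `torch:default` as the baseline are illustrative assumptions, not part of the benchmark script.

```python
# (seq_len, threads, latency_s, kernel), copied from the CPU table above.
rows = [
    (128, 8, 0.0005, "ort:flash"),
    (128, 8, 0.0046, "torch:default"),
    (2048, 8, 0.1050, "ort:flash"),
    (2048, 8, 0.9835, "torch:default"),
]


def latency_ratios(rows, baseline="torch:default", candidate="ort:flash"):
    """Return baseline/candidate latency ratios keyed by (seq_len, threads)."""
    by_key = {(s, t, k): lat for s, t, lat, k in rows}
    ratios = {}
    for (s, t, k), lat in by_key.items():
        if k == candidate and (s, t, baseline) in by_key:
            ratios[(s, t)] = by_key[(s, t, baseline)] / lat
    return ratios


ratios = latency_ratios(rows)
for (seq_len, threads), ratio in sorted(ratios.items()):
    print(f"seq_len={seq_len} threads={threads}: {ratio:.1f}x")
```

Because torch uses BNSH and ort uses BSNH for this format, such ratios are a rough signal rather than a strict comparison.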
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
        else:
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends support to [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
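The four attention bias shapes listed in the description can be expressed as a small validity check. This is a sketch under the assumption that B, N, S and T stand for batch size, number of heads, sequence length and total (key/value) sequence length; the function name is hypothetical and is not taken from the operator code.

```python
def is_supported_attention_bias_shape(bias_shape, batch_size, num_heads,
                                      sequence_length, total_sequence_length):
    """Return True for [1|B, 1|N, S, T], the shapes listed in the description."""
    if len(bias_shape) != 4:
        return False
    b, n, s, t = bias_shape
    return (
        b in (1, batch_size)          # dim 0 may broadcast over the batch
        and n in (1, num_heads)       # dim 1 may broadcast over the heads
        and s == sequence_length      # last two dims must match exactly
        and t == total_sequence_length
    )


B, N, S, T = 2, 4, 8, 8
for shape in [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T), (3, N, S, T)]:
    print(shape, is_supported_attention_bias_shape(shape, B, N, S, T))
```

Only the first two dimensions broadcast in this sketch; the description does not mention broadcasting over S or T, so those are required to match.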
2024-08-16 22:40:04 +00:00
            backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH]
    else:
        device_id = 0
        device = torch.device("cpu")
        formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH]
        enable_cuda_graph = False
        provider = "CPUExecutionProvider"
        backends = [SdpaKernel.DEFAULT]
    configs = get_test_configs(args)
2024-08-22 00:30:16 +00:00
    print(
        "\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tsdpa_kernel\trequest_kernel"
    )
2024-06-12 20:04:25 +00:00

    for input_format in formats:
        for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
            config = MultiHeadAttentionConfig(
                batch_size=batch_size,
                sequence_length=sequence_length,
                num_heads=num_heads,
                head_size=head_size,
                causal=causal,
                use_kv_cache=past_sequence_length > 0,
                past_sequence_length=past_sequence_length,
                max_cache_sequence_length=None,
                kv_sequence_length=None,
                provider=provider,
                enable_cuda_graph=enable_cuda_graph,
                device=device,
                dtype=torch.float16 if use_gpu else torch.float,
                share_past_present_buffer=False,
                input_format=input_format,
                has_attn_bias=args.has_attn_bias,
                broadcast_attn_bias_dim_0=args.broadcast_attn_bias_dim_0,
                broadcast_attn_bias_dim_1=args.broadcast_attn_bias_dim_1,
            )
            for attention_kernel in backends:
                sess_options = SessionOptions()
                sess_options.intra_op_num_threads = intra_op_num_threads
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, then removed
backward and bfloat16 related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works on Linux. For Windows, we will
enable it later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
Note that flash attention cannot use packed QKV format, so extra
Transpose is needed. We found that TensorRT kernel is faster for
sequence length <= 512 for packed QKV. The reason might be no transpose
is needed for TensorRT kernel in this format.
We also notice that, TensorRT kernel is faster for stable diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim |
flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention
(TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses huge memory while compiling flash attention CUDA kernel. Linux
build with CUDA might fail when machine has limited memory while number
of CPUs is large. Walkaround is to use a build machine with larger
memory, or use argument like `--nvcc_threads 1` to limit nvcc threads in
build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
|
|
|
|
2024-06-12 20:04:25 +00:00
    if use_gpu:
2024-08-22 00:30:16 +00:00
        request_kernel = get_gpu_kernel_name(attention_kernel)
2024-06-12 20:04:25 +00:00
    else:
2024-08-22 00:30:16 +00:00
        request_kernel = get_cpu_kernel_name(config)
2024-08-22 00:30:16 +00:00
    if "math" in request_kernel:
2024-06-12 20:04:25 +00:00
        # Skip large sequence length for Unfused kernel to avoid OOM.
        if not enable_unfused:
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses the BSNH
layout, so the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
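
To make the layout difference concrete, here is a minimal pure-Python sketch of how the same element is addressed in a flat BSNH buffer versus a flat BNSH buffer, and how one layout can be copied into the other. The function names are hypothetical and are not taken from benchmark_mha.py; B, S, N, H stand for batch size, sequence length, number of heads and head size.

```python
def bsnh_offset(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a row-major BSNH buffer."""
    return ((b * S + s) * N + n) * H + h


def bnsh_offset(b, s, n, h, S, N, H):
    """Flat offset of the same element in a row-major BNSH buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Copy a flat BSNH buffer into a new flat BNSH buffer (a transpose of S and N)."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = bsnh_offset(b, s, n, h, S, N, H)
                    dst = bnsh_offset(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


# Tiny example: B=1, S=2, N=2, H=1, values numbered in BSNH order.
example = bsnh_to_bnsh([0, 1, 2, 3], B=1, S=2, N=2, H=1)
print(example)
```

The extra copy (or strided access) implied by such a transpose is one reason a latency gap between the two frameworks for the Q,K,V format should be read with care.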
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
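
The tflops column can be related to the latency column with a short calculation. The sketch below assumes the common convention of counting 4 * B * N * S^2 * H floating-point operations for non-causal attention (two matrix multiplications, two operations per multiply-add); the exact formula used by benchmark_mha.py is not shown here, so treat this as an approximation. The function name is hypothetical.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPs/s for non-causal attention.

    Assumption: FLOPs = 4 * B * N * S^2 * H (QK^T and PV matmuls,
    2 FLOPs per multiply-add). Softmax cost is ignored.
    """
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_s / 1e12


# First row of the GPU table: Q,KV | 4 | 2048 | 32 | 128 | 0.0015 s
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f} TFLOPs/s")
```

For the first table row this gives roughly 183 TFLOPs/s against the reported 179.5; the gap is consistent with the latency being rounded to four decimal places in the table.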
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
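
One way to read the CPU table is to compute thread scaling for a fixed shape. The sketch below uses the ort:flash latencies for seq_len=2048 copied from the rows above (rows with threads=0 are left out, since the table does not say how many threads that setting resolves to). The helper name is hypothetical.

```python
# Latencies in seconds for ort:flash, seq_len=2048, keyed by the threads column.
latency_by_threads = {1: 0.4724, 2: 0.2461, 4: 0.1535, 8: 0.1050}


def scaling(latencies):
    """Return {threads: (speedup vs. 1 thread, parallel efficiency)}."""
    base = latencies[1]
    return {t: (base / lat, base / lat / t) for t, lat in sorted(latencies.items())}


for threads, (speedup, efficiency) in scaling(latency_by_threads).items():
    print(f"threads={threads}: speedup={speedup:.2f}x, efficiency={efficiency:.0%}")
```

With these numbers, 8 threads run about 4.5 times faster than 1 thread, while the torch:default rows stay near 0.98 s to 1.0 s regardless of the threads value.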
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
            if config.verbose:
                print(f"skip unfused kernel for {vars(config)}")
2024-06-12 20:04:25 +00:00
            continue
2024-06-12 20:04:25 +00:00
        # Unfused kernel does not support packed QKV or packed KV formats.
        if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]:
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses the BSNH
layout, so the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
if config.verbose:
    print(f"skip input_format for {vars(config)}")
2024-06-12 20:04:25 +00:00
continue
2024-08-22 00:30:16 +00:00
if use_gpu:
    actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options)
    if actual_kernel is None:
        print(f"Warning: skip {config} since kernel from debug info is None")
        continue
else:
    # CPU has no debug info for now.
    actual_kernel = request_kernel

session = create_session(config, sess_options, attention_kernel=attention_kernel)
2024-06-12 20:04:25 +00:00
input_dict = config.random_inputs()
2024-06-12 20:04:25 +00:00
# warm up session
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported only [1, N, S, T]. This change extends support to [1, N, S, T],
[B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for the CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
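The broadcast rule described in this commit message can be illustrated with a small index-mapping sketch. It is not the ORT kernel code; the function names `check_attention_bias_shape` and `bias_offset` are hypothetical, and the sketch assumes the bias is stored contiguously in row-major order. A dimension of size 1 in the first or second axis is treated as a stride of zero, so every batch (or head) reads the same slice.

```python
from typing import Sequence


def check_attention_bias_shape(
    bias_shape: Sequence[int],
    batch_size: int,
    num_heads: int,
    seq_len: int,
    total_seq_len: int,
) -> None:
    """Validate that bias_shape is one of [1|B, 1|N, S, T] (hypothetical helper)."""
    if len(bias_shape) != 4:
        raise ValueError("attention bias must be 4D")
    b, n, s, t = bias_shape
    if b not in (1, batch_size):
        raise ValueError("dim 0 must be 1 or batch_size")
    if n not in (1, num_heads):
        raise ValueError("dim 1 must be 1 or num_heads")
    if (s, t) != (seq_len, total_seq_len):
        raise ValueError("last two dims must be (S, T)")


def bias_offset(bias_shape: Sequence[int], b: int, n: int, s: int, t: int) -> int:
    """Flat offset into the bias buffer for element (b, n, s, t) of the full [B, N, S, T] view."""
    dim_b, dim_n, dim_s, dim_t = bias_shape
    bb = 0 if dim_b == 1 else b  # broadcast over batch
    nn = 0 if dim_n == 1 else n  # broadcast over heads
    return ((bb * dim_n + nn) * dim_s + s) * dim_t + t
```

With a `[1, 1, S, T]` bias, every `(b, n)` pair maps to the same `S x T` block, which is why the unfused and efficient-attention kernels only need to know which of the two leading dimensions are broadcast.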
try:
    _ = measure_latency(session, input_dict)
except Exception as e:
2024-08-22 00:30:16 +00:00
    print(f"Failed to run {request_kernel=} for {config=}. Exception: {e}")
    continue
2024-06-12 20:04:25 +00:00
latency_list = []
for _ in range(repeats):
    latency = measure_latency(session, input_dict)
    latency_list.append(latency)
average_latency = statistics.mean(latency_list)

del session
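The warm-up and repeat pattern used in the loop above can be sketched in isolation as follows. This is a simplified stand-in, not the script's `measure_latency`: the name `benchmark` and the `run_once` callable are assumptions made for illustration, and wall-clock timing with `time.perf_counter` is only meaningful for GPU work if `run_once` itself waits for the device to finish.

```python
import statistics
import time
from typing import Callable, List


def benchmark(run_once: Callable[[], None], repeats: int = 100, warmup: int = 1) -> float:
    """Return the mean wall-clock latency (seconds) of run_once over `repeats` calls."""
    # Warm-up runs are excluded from timing so one-time costs
    # (allocation, kernel selection, caches) do not skew the mean.
    for _ in range(warmup):
        run_once()

    latencies: List[float] = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        latencies.append(time.perf_counter() - start)
    return statistics.mean(latencies)
```

A call such as `benchmark(lambda: sum(range(1000)), repeats=10)` returns a small positive number of seconds; the exact value depends on the machine.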
format_str = InputFormats.input_format_str(input_format)
2024-06-12 20:04:25 +00:00
# compute TFLOPS per second
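As a rough guide to how a TFLOPS figure of this kind can be derived, the sketch below counts the two matrix multiplications in attention (Q·Kᵀ and P·V), each costing about 2·B·N·S·T·H floating point operations. The function name `attention_tflops` is hypothetical, and the formula is an assumption that happens to agree with the reported tables to within the rounding of the latency column; it ignores softmax and does not adjust for causal masking.

```python
def attention_tflops(
    batch_size: int,
    sequence_length: int,
    past_sequence_length: int,
    num_heads: int,
    head_size: int,
    latency_seconds: float,
) -> float:
    """Approximate attention throughput in TFLOPS (non-causal, softmax ignored)."""
    total_sequence_length = sequence_length + past_sequence_length
    # Two matmuls (Q @ K^T and P @ V), each 2 * B * N * S * T * H FLOPs.
    flops = 4 * batch_size * num_heads * sequence_length * total_sequence_length * head_size
    return flops / latency_seconds / 1e12
```

For batch_size=4, sequence_length=2048, num_heads=32, head_size=128 and a latency of 0.0015 s this gives about 183 TFLOPS, which is close to the 179.5 reported for that row; the gap is plausibly due to the latency being rounded to four decimals.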
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
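To make the layout difference concrete, the following sketch converts a flat buffer from BSNH order (batch, sequence, heads, head size) to BNSH order. It is a plain-Python illustration with a hypothetical function name, `bsnh_to_bnsh`, and is not taken from the benchmark script; in practice a tensor library's transpose would do the same job.

```python
from typing import List, Sequence


def bsnh_to_bnsh(flat: Sequence[float], B: int, S: int, N: int, H: int) -> List[float]:
    """Reorder a row-major BSNH buffer into row-major BNSH order."""
    if len(flat) != B * S * N * H:
        raise ValueError("buffer length does not match B*S*N*H")
    out: List[float] = [0.0] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = ((b * S + s) * N + n) * H + h  # BSNH offset
                    dst = ((b * N + n) * S + s) * H + h  # BNSH offset
                    out[dst] = flat[src]
    return out
```

Both layouts hold the same values, so the arithmetic is identical; only the memory access pattern differs, which is one reason the torch and ort latencies are not directly comparable.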
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
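The tflops column can be approximately reproduced from the latency column. The sketch below assumes the usual attention FLOP count of two matrix multiplications (Q·Kᵀ and P·V), each costing 2·B·N·S²·H floating-point operations, halved for causal attention. The argument order follows the `flops(...)` call in the script, but the function bodies here are assumptions, not the script's actual code.

```python
def attention_flops(batch_size, sequence_length, head_size, num_heads, causal=False):
    """Approximate FLOPs of one attention forward pass (assumed formula)."""
    # Two matmuls (Q*K^T and P*V), each 2 * B * N * S * S * H operations.
    total = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    # Assumption: causal masking skips roughly half of the work.
    return total // 2 if causal else total


def tflops_per_second(flop_count, latency_seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPs per second."""
    return flop_count / latency_seconds / 1e12


# First table row: batch 4, sequence 2048, 32 heads, head size 128.
flop_count = attention_flops(4, 2048, 128, 32)
estimate = tflops_per_second(flop_count, 0.0015)
```

With the rounded latency of 0.0015 s this gives roughly 183 TFLOPs per second, close to the 179.5 reported in the table; the gap is consistent with the latency column being rounded to four decimal places.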
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare CPU and CUDA latency against PyTorch SDPA.
2024-07-27 01:45:14 +00:00
speed = None
if past_sequence_length == 0:
    speed = tflops_per_second(
        flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency
    )

row = {
    "use_gpu": use_gpu,
    "enable_cuda_graph": enable_cuda_graph,
    "format": format_str,
    "causal": causal,
    "batch_size": batch_size,
    "sequence_length": sequence_length,
    "past_sequence_length": past_sequence_length,
    "num_heads": num_heads,
    "head_size": head_size,
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends support to [1, N, S, T],
[B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for the CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
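The broadcast rule for the four supported bias shapes can be sketched in plain Python: a dimension of size 1 in the first or second position is reused for every batch or head index. This is an illustrative reference only, with hypothetical helper names; it is not the CUDA or CPU kernel code.

```python
def bias_offset(b, n, s, t, bias_shape):
    """Flat offset into an attention bias of shape (B or 1, N or 1, S, T)."""
    dim0, dim1, S, T = bias_shape
    bias_b = 0 if dim0 == 1 else b  # broadcast over the batch dimension
    bias_n = 0 if dim1 == 1 else n  # broadcast over the head dimension
    return ((bias_b * dim1 + bias_n) * S + s) * T + t


def add_attention_bias(scores, bias, B, N, S, T, bias_shape):
    """Return scores (flat, shape B x N x S x T) with the broadcast bias added."""
    out = list(scores)
    for b in range(B):
        for n in range(N):
            for s in range(S):
                for t in range(T):
                    index = ((b * N + n) * S + s) * T + t
                    out[index] += bias[bias_offset(b, n, s, t, bias_shape)]
    return out


# B=2, N=2, S=1, T=1 with a bias of shape [B, 1, S, T]:
# each batch entry shares one bias value across both heads.
biased = add_attention_bias([0.0] * 4, [1.0, 2.0], 2, 2, 1, 1, (2, 1, 1, 1))
```

A full [B, N, S, T] bias needs no broadcasting, and [1, 1, S, T] reuses a single S x T slice for every batch entry and head.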
2024-08-16 22:40:04 +00:00
    "has_attn_bias": args.has_attn_bias,
    "broadcast_attn_bias_dim_0": args.broadcast_attn_bias_dim_0,
    "broadcast_attn_bias_dim_1": args.broadcast_attn_bias_dim_1,
2024-07-27 01:45:14 +00:00
    "intra_op_num_threads": intra_op_num_threads,
    "average_latency": average_latency,
    "tflops": speed,
2024-08-22 00:30:16 +00:00
    "request_kernel": request_kernel,
    "kernel": actual_kernel,
2024-07-27 01:45:14 +00:00
}
csv_writer.writerow(row)
2024-06-12 20:04:25 +00:00
2024-07-27 01:45:14 +00:00
speed = f"{speed:.2f}" if speed is not None else "NA"
2024-06-12 20:04:25 +00:00
|
|
|
print(
|
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported only [1, N, S, T]. This change extends support to [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
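The four accepted bias shapes can be read as ordinary broadcasting over the batch and head dimensions. The following is a minimal pure-Python sketch and not the ORT kernel code: the helper names `is_supported_bias_shape` and `bias_offset` are hypothetical, and a row-major layout is assumed.

```python
def is_supported_bias_shape(bias_shape, batch_size, num_heads, seq_len, total_seq_len):
    """Return True if bias_shape is [1 or B, 1 or N, S, T]."""
    if len(bias_shape) != 4:
        return False
    dim_b, dim_n, dim_s, dim_t = bias_shape
    return (
        dim_b in (1, batch_size)
        and dim_n in (1, num_heads)
        and dim_s == seq_len
        and dim_t == total_seq_len
    )


def bias_offset(bias_shape, b, n, s, t):
    """Flat row-major offset of the bias element added to score[b, n, s, t].

    A dimension of size 1 is broadcast, so its index is clamped to 0.
    """
    dim_b, dim_n, dim_s, dim_t = bias_shape
    bi = b if dim_b > 1 else 0
    ni = n if dim_n > 1 else 0
    return ((bi * dim_n + ni) * dim_s + s) * dim_t + t
```

With a bias of shape [B, 1, S, T], every head `n` of a given batch entry reads the same S x T slice, which is the case the unfused and efficient attention kernels were updated to handle.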
    f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t"
    f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
    f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}"
)

Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For the Q,K,V format, torch uses the BNSH layout while ort uses BSNH, so
the results are not an apples-to-apples comparison. However, a large
latency difference could still be a warning sign.
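The two layouts hold the same elements and differ only in the order of the sequence (S) and head (N) dimensions. A minimal sketch of that relationship on a flat list, using hypothetical helper names that are not part of the benchmark script:

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a [B, S, N, H] tensor."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Flat offset of the same element in a [B, N, S, H] tensor."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Reorder a flat BSNH buffer into BNSH order."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out
```

Because each framework is fed the layout it expects, any cost of converting between the two is outside the measured region, which is one reason the Q,K,V rows should be compared with care.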
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
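The tflops column can be sanity-checked from the shape and latency columns. The sketch below assumes the usual non-causal attention cost of two batched matrix multiplications (QK^T and the product with V), that is, 4 * B * S^2 * N * H floating point operations. That formula is an assumption here rather than a quote from the benchmark script, although it agrees with the rows above to within the rounding of the latency column.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size):
    """Assumed FLOP count for non-causal attention without past state."""
    return 4 * batch_size * sequence_length * sequence_length * num_heads * head_size


def tflops(flops, latency_seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPs per second."""
    return flops / latency_seconds / 1e12


# Row "Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash"
flops = attention_flops(8, 4096, 32, 128)
print(f"{tflops(flops, 0.0115):.1f} TFLOPs/s")
```

The printed value lands close to the reported 190.8; the small gap comes from the latency being rounded to four decimal places in the table.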
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
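To summarize the CPU table, the lowest ORT latency at each sequence length can be divided into the lowest torch latency. The numbers below are copied from the rows above; the dictionary layout and the `speedups` helper are only an illustration, and the layout caveat (BNSH versus BSNH) still applies to the comparison.

```python
# seq_len -> (lowest ort latency in s, lowest torch:default latency in s)
best_latency = {
    128: (0.0005, 0.0045),
    256: (0.0019, 0.0161),
    512: (0.0077, 0.0624),
    1024: (0.0286, 0.2482),
    2048: (0.1038, 0.9835),
}


def speedups(rows):
    """Ratio of torch latency to ort latency for each sequence length."""
    return {seq_len: torch_s / ort_s for seq_len, (ort_s, torch_s) in rows.items()}


for seq_len, ratio in speedups(best_latency).items():
    print(f"seq_len={seq_len}: {ratio:.1f}x")
```

On this machine the ratio stays between roughly 8x and 9.5x across the tested sequence lengths.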
### Motivation and Context
To compare latency against PyTorch SDPA on CPU and CUDA.
2024-07-27 01:45:14 +00:00
def run_torch_test(
    csv_writer: csv.DictWriter,
    args: argparse.Namespace,
):
    use_gpu: bool = args.use_gpu
    causal: bool = args.causal

    configs = get_test_configs(args)

    if use_gpu:
        if not torch.cuda.is_available():
            return
        device_id = torch.cuda.current_device()
        device = torch.device("cuda", device_id)
        dtype = torch.float16
        backends = [
            None,
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.MATH,
        ]
    else:
        device = torch.device("cpu")
        dtype = torch.float32
        backends = [None]

    backend_names = {
        SDPBackend.FLASH_ATTENTION: "torch:flash",
        SDPBackend.EFFICIENT_ATTENTION: "torch:efficient",
        SDPBackend.CUDNN_ATTENTION: "torch:cudnn",
        SDPBackend.MATH: "torch:math",
        None: "torch:default",
    }

    # Test PyTorch latency
    for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
        for backend in backends:
            if backend == SDPBackend.MATH and not enable_unfused:
                continue
            if backend == SDPBackend.FLASH_ATTENTION and platform.system() != "Linux":
                continue

            backend_name = backend_names[backend]
            try:
                with torch.no_grad():
                    torch_latency = run_torch_sdpa(
                        batch_size,
                        sequence_length,
                        sequence_length,
                        num_heads,
                        head_size,
                        causal,
                        has_mask=False,
                        mask_dim=2,
                        mask_dtype=torch.bool,
                        device=device,
                        dtype=dtype,
                        backend=backend,
                        repeats=args.repeats,
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For the Q,K,V format, torch uses the BNSH layout while ort uses BSNH, so
the results are not an apples-to-apples comparison. However, a large
latency difference could still be a warning sign.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
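
The tflops column can be sanity-checked against the latency column. The sketch below is an assumption, not the script's code: it uses the common count of 4 * B * S^2 * N * H floating-point operations for non-causal attention (two matrix multiplications, two operations per multiply-add), halved for causal attention. The function names mirror the `flops` and `tflops_per_second` calls in benchmark_mha.py, but the bodies are illustrative.

```python
def flops(batch_size, sequence_length, head_size, num_heads, causal):
    # Assumed FLOP count: QK^T and PV matmuls, 2 ops per multiply-add.
    total = 4 * batch_size * sequence_length**2 * num_heads * head_size
    return total // 2 if causal else total


def tflops_per_second(flop, seconds):
    # Convert an operation count and a latency in seconds to TFLOPs per second.
    return flop / seconds / 1e12


# First table row: batch_size=4, sequence_length=2048, num_heads=32, head_size=128.
flop = flops(4, 2048, 128, 32, causal=False)
speed = tflops_per_second(flop, 0.0015)
print(flop, round(speed, 1))  # roughly 183 TFLOPs/s at the rounded latency
```

Under this assumption the rounded latency of 0.0015 s gives roughly 183 TFLOPs per second, close to the reported 179.5; the gap is consistent with the latency column being rounded to four decimal places.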
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
|
|
|
)
|
|
|
|
|
except RuntimeError:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency)
|
|
|
|
|
input_format = "Q,K,V"
|
|
|
|
|
print(
|
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
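The four bias shapes listed in the description can be read as broadcasting over the batch dimension, the head dimension, or both. The sketch below is a hypothetical illustration of that indexing rule, not the kernels' actual code: `bias_offset` is an invented helper that returns the flat row-major offset of the bias element used for batch `b`, head `n`, query position `s` and key position `t`.

```python
def bias_offset(bias_shape, b, n, s, t):
    # bias_shape is (B or 1, N or 1, S, T); a size-1 dimension is broadcast,
    # so every batch (or head) reads the same slice.
    bias_b, bias_n, seq_len, total_len = bias_shape
    b_idx = 0 if bias_b == 1 else b
    n_idx = 0 if bias_n == 1 else n
    return ((b_idx * bias_n + n_idx) * seq_len + s) * total_len + t


# Illustrative sizes: B=2, N=3, S=4, T=5.
for shape in [(1, 3, 4, 5), (2, 3, 4, 5), (2, 1, 4, 5), (1, 1, 4, 5)]:
    print(shape, bias_offset(shape, b=1, n=2, s=3, t=4))
```

With a [1, 1, S, T] bias, every batch and head maps to the same S x T slice, which is why a single mask-like bias can be shared across the whole input.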
2024-08-16 22:40:04 +00:00
|
|
|
f"{input_format}\t{causal}\t{False}\t{batch_size}\t"
|
|
|
|
|
f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
|
2024-08-22 00:30:16 +00:00
|
|
|
f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}\t{backend_name}"
|
|
|
|
)
|
|
|
|
|
row = {
|
|
|
|
|
"use_gpu": use_gpu,
|
|
|
|
|
"enable_cuda_graph": False,
|
|
|
|
|
"format": input_format,
|
|
|
|
|
"causal": causal,
|
|
|
|
|
"batch_size": batch_size,
|
|
|
|
|
"sequence_length": sequence_length,
|
|
|
|
|
"past_sequence_length": past_sequence_length,
|
|
|
|
|
"num_heads": num_heads,
|
|
|
|
|
"head_size": head_size,
|
|
|
|
"has_attn_bias": False,
|
|
|
|
|
"broadcast_attn_bias_dim_0": False,
|
|
|
|
|
"broadcast_attn_bias_dim_1": False,
|
|
|
|
"intra_op_num_threads": torch.get_num_threads(),
|
|
|
|
|
"average_latency": torch_latency,
|
|
|
|
|
"tflops": speed,
|
2024-08-22 00:30:16 +00:00
|
|
|
"request_kernel": backend_name,
|
        "kernel": backend_name,
    }
    csv_writer.writerow(row)


def run_tflops_tests(args):
    features = "gpu" if args.use_gpu else "cpu"
    if args.causal:
        features += "_causal"
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends support to [1, N, S, T],
[B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for the CUDA and CPU EPs.
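The shape rule above can be illustrated with a minimal pure-Python sketch. The helper name `broadcast_attention_bias` and the nested-list representation are assumptions made for illustration; they are not the operator's actual implementation, which works on tensors inside the kernels.

```python
def broadcast_attention_bias(bias, batch_size, num_heads):
    """Expand a nested-list bias of shape [b, n, S, T] to [B, N, S, T].

    Hypothetical helper: b must be 1 or B and n must be 1 or N, which
    covers the four shapes listed in the description.
    """
    b, n = len(bias), len(bias[0])
    if b not in (1, batch_size) or n not in (1, num_heads):
        raise ValueError(
            f"cannot broadcast [{b}, {n}, S, T] to [{batch_size}, {num_heads}, S, T]"
        )
    # A dimension of size 1 is reused for every index along that axis.
    # The inner S x T matrices are shared, not copied.
    return [
        [bias[i if b > 1 else 0][j if n > 1 else 0] for j in range(num_heads)]
        for i in range(batch_size)
    ]


# A [1, 1, 2, 2] bias expanded for batch_size=2, num_heads=3.
bias = [[[[0.0, -1.0], [0.0, 0.0]]]]
full = broadcast_attention_bias(bias, batch_size=2, num_heads=3)
shape = (len(full), len(full[0]), len(full[0][0]), len(full[0][0][0]))
print(shape)  # (2, 3, 2, 2)
```

A real kernel would normally avoid materializing the expanded tensor and would instead reuse the same bias slice when indexing, which is what the shared inner lists mimic here.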
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
2024-08-16 22:40:04 +00:00
    if args.past_sequence_length > 0:
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
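As a rough sketch of the CSV output mentioned in the list above, the snippet below builds a result file name from the run settings and writes rows with `csv.DictWriter`. The helper names `make_csv_filename` and `write_rows`, the reduced column list, and the in-memory buffer are assumptions for illustration; the script's own column list is longer.

```python
import csv
import io
from datetime import datetime


def make_csv_filename(use_gpu, causal, past_sequence_length, use_torch, now):
    """Hypothetical helper: encode the run settings in the file name."""
    features = "gpu" if use_gpu else "cpu"
    if causal:
        features += "_causal"
    if past_sequence_length > 0:
        features += "_past"
    return "benchmark_mha_{}_{}_{}.csv".format(
        features,
        "torch" if use_torch else "ort",
        now.strftime("%Y%m%d-%H%M%S"),
    )


def write_rows(stream, column_names, rows):
    """Write a header line followed by one CSV line per result row."""
    writer = csv.DictWriter(stream, fieldnames=column_names)
    writer.writeheader()
    for row in rows:
        writer.writerow(row)


# Reduced column list, for illustration only.
columns = ["format", "batch_size", "sequence_length", "average_latency", "kernel"]
buffer = io.StringIO()  # stands in for an opened CSV file
write_rows(
    buffer,
    columns,
    [
        {
            "format": "Q,K,V",
            "batch_size": 4,
            "sequence_length": 2048,
            "average_latency": 0.0016,
            "kernel": "ort:flash",
        }
    ],
)
name = make_csv_filename(True, False, 0, False, datetime(2024, 7, 27, 1, 45, 14))
print(name)
print(buffer.getvalue())
```

Note that `csv.DictWriter` quotes the `Q,K,V` format value because it contains commas, so the file stays parseable.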
For the Q,K,V format, torch uses the BNSH layout while ORT uses BSNH, so
the results are not an apples-to-apples comparison. However, a large
latency difference could still be a warning sign.
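To make the layout difference concrete, here is a small pure-Python sketch of how the same element lands at different flat offsets in BSNH and BNSH. The function names are hypothetical, and the explicit loop is only for clarity; a real benchmark would use a tensor transpose.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a [B, S, N, H] buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Flat offset of the same element in a [B, N, S, H] buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Copy a flat BSNH buffer into BNSH order (swap the S and N axes)."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


# B=1, S=2, N=2, H=1: BSNH order is (s0,n0), (s0,n1), (s1,n0), (s1,n1).
print(bsnh_to_bnsh([0, 1, 2, 3], B=1, S=2, N=2, H=1))  # [0, 2, 1, 3]
```

Because the two layouts differ in memory access pattern, small latency gaps between torch and ORT on this format may reflect the layout rather than the kernel.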
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
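The tflops column in the table above appears consistent with counting 4·B·N·S²·H floating-point operations per attention call (two matrix multiplications at 2 FLOPs per multiply-add) and dividing by latency. The sketch below works under that assumption; the function names and the halving for causal attention are illustrative assumptions, not taken from the table.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size, causal=False):
    """Approximate FLOPs of one attention call (assumed formula).

    Q*K^T and P*V each cost 2 * B * N * S * S * H FLOPs. Halving the
    count for causal attention is an assumption made for this sketch.
    """
    flops = 4 * batch_size * sequence_length * sequence_length * num_heads * head_size
    return flops // 2 if causal else flops


def tflops(latency_seconds, batch_size, sequence_length, num_heads, head_size, causal=False):
    """Throughput in TFLOPS for a measured latency in seconds."""
    flops = attention_flops(batch_size, sequence_length, num_heads, head_size, causal)
    return flops / latency_seconds / 1e12


# First table row: batch_size=4, sequence_length=2048, num_heads=32, head_size=128.
flops = attention_flops(4, 2048, 32, 128)
print(flops)  # 274877906944

# Latency implied by the reported 179.5 TFLOPS, rounded like the table.
implied_latency = round(flops / (179.5 * 1e12), 4)
print(implied_latency)
```

The latency column is rounded to four decimals, so recomputing tflops from the printed latency will not reproduce the tflops column exactly.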
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare CPU and CUDA latency against PyTorch SDPA.
2024-07-27 01:45:14 +00:00
        features += "_past"
    csv_filename = "benchmark_mha_{}_{}_{}.csv".format(
        features,
        "torch" if args.torch else "ort",
        datetime.now().strftime("%Y%m%d-%H%M%S"),
    )
    with open(csv_filename, mode="a", newline="") as csv_file:
        column_names = [
            "use_gpu",
            "enable_cuda_graph",
            "format",
            "causal",
            "batch_size",
            "sequence_length",
            "past_sequence_length",
            "num_heads",
            "head_size",
            "has_attn_bias",
            "broadcast_attn_bias_dim_0",
            "broadcast_attn_bias_dim_1",
            "intra_op_num_threads",
            "average_latency",
            "tflops",
2024-08-22 00:30:16 +00:00
            "request_kernel",
            "kernel",
        ]
        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()

        if args.torch:
            run_torch_test(csv_writer, args)
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
senarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the result is not apple-to-apple. However, if the latency difference is
large, that could be a warning.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency
(s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
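The tflops column can be approximated from the other columns. The sketch below assumes the common FLOP count for non-causal attention, 4 * B * N * S * S * H (two matrix multiplications at 2 * B * N * S * S * H each); the function name `attention_tflops` is hypothetical and is not taken from the benchmark script. Because the latencies in the table are rounded to four decimals, recomputed values only roughly match the reported ones.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size, causal=False):
    """Approximate FLOPs of one attention forward pass (assumed formula)."""
    # Q x K^T and P x V each cost 2 * B * N * S * S * H FLOPs.
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    # Assumption: a causal mask skips roughly half of the work.
    return flops // 2 if causal else flops


def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s, causal=False):
    """Convert a measured latency in seconds into TFLOPs per second."""
    flops = attention_flops(batch_size, sequence_length, num_heads, head_size, causal)
    return flops / latency_s / 1e12


if __name__ == "__main__":
    # First table row: B=4, S=2048, N=32, H=128, latency 0.0015 s (rounded).
    print(round(attention_tflops(4, 2048, 32, 128, 0.0015), 1))
```

With the rounded latency of 0.0015 s this gives a value a few percent above the 179.5 reported in the first row, which is consistent with the rounding of the latency column.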
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare ORT latency with PyTorch SDPA on CPU and CUDA.
2024-07-27 01:45:14 +00:00
else:
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends support to [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
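The broadcast rule described above can be sketched in plain Python. This is only an illustration under the assumption that the bias is stored row-major and that a dimension of size 1 is reused for every batch or head; the helper names `is_supported_attention_bias_shape` and `bias_offset` are hypothetical and do not come from the ONNX Runtime sources.

```python
def is_supported_attention_bias_shape(bias_shape, batch_size, num_heads,
                                      sequence_length, total_sequence_length):
    """Return True for [1|B, 1|N, S, T], the four shapes listed above."""
    if len(bias_shape) != 4:
        return False
    b, n, s, t = bias_shape
    return (
        b in (1, batch_size)
        and n in (1, num_heads)
        and s == sequence_length
        and t == total_sequence_length
    )


def bias_offset(bias_shape, b, n, s, t):
    """Flat row-major offset of bias[b, n, s, t], honoring broadcast dims."""
    dim_b, dim_n, dim_s, dim_t = bias_shape
    # A dimension of size 1 is broadcast: every index maps to element 0.
    b = 0 if dim_b == 1 else b
    n = 0 if dim_n == 1 else n
    return ((b * dim_n + n) * dim_s + s) * dim_t + t


if __name__ == "__main__":
    B, N, S, T = 2, 4, 3, 5
    for shape in [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T)]:
        print(shape, is_supported_attention_bias_shape(shape, B, N, S, T))
```

A kernel that reads the bias through an offset function like this does not need to materialize the full [B, N, S, T] tensor, which is the usual reason for supporting broadcast shapes directly.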
2024-08-16 22:40:04 +00:00
run_tflops_test(csv_writer, args)
2024-07-27 01:45:14 +00:00
2024-06-12 20:04:25 +00:00
def plot_prompt_performance(
    model_name: str,
    batch_size: int,
    num_heads: int,
    head_size: int,
    max_seq_len: int,
):
    import triton

    formats = InputFormats.get_name_list()

    # Exclude cross attention since kernel crashes for some configuration.
    formats = formats[:-1]

    settings = {
        "line_vals": formats,
        "line_names": ["ORT-MHA:" + name for name in formats],
        "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)],
    }
2024-07-27 01:45:14 +00:00
    sm = get_compute_capability()
2024-06-12 20:04:25 +00:00
    configs = [
        triton.testing.Benchmark(
            x_names=["sequence_length"],
            x_vals=[2**i for i in range(6, 17) if 2**i <= max_seq_len],
            line_arg="input_format",
            ylabel="ms",
            **settings,
            plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{head_size}-fp16",
            args={
                "batch_size": batch_size,
                "num_heads": num_heads,
                "head_size": head_size,
            },
        )
    ]

    @triton.testing.perf_report(configs)
    def benchmark(
        input_format: str,
        sequence_length: int,
        batch_size: int,
        num_heads: int,
        head_size: int,
        device="cuda",
    ):
        warmup = 15
        repeat = 100

        config: MultiHeadAttentionConfig = MultiHeadAttentionConfig(
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_heads=num_heads,
            head_size=head_size,
2024-07-27 01:45:14 +00:00
            causal=False,
2024-06-12 20:04:25 +00:00
            past_sequence_length=0,
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supports relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supports [1, N, S, T]. This will extend the support to allow [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
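The broadcast rule described above can be sketched in a few lines of pure Python. This is only an illustration under stated assumptions, not the kernel code: the helper names `is_supported_bias_shape` and `bias_offset` are hypothetical, and the bias is assumed to be stored row-major.

```python
def is_supported_bias_shape(bias_shape, B, N, S, T):
    """Return True if bias_shape is one of [1|B, 1|N, S, T]."""
    b, n, s, t = bias_shape
    return b in (1, B) and n in (1, N) and (s, t) == (S, T)


def bias_offset(bias_shape, b, n, s, t):
    """Flat row-major offset of the bias element used for logical index (b, n, s, t).

    A dimension of size 1 is broadcast, so its index is pinned to 0.
    """
    bias_b, bias_n, S, T = bias_shape
    bi = b if bias_b > 1 else 0
    ni = n if bias_n > 1 else 0
    return ((bi * bias_n + ni) * S + s) * T + t


# A [B, 1, S, T] bias is shared by every head of a given batch entry.
shape = (2, 1, 4, 6)
same_for_all_heads = bias_offset(shape, 1, 0, 3, 5) == bias_offset(shape, 1, 2, 3, 5)
```

With a `[B, 1, S, T]` bias, all heads of one batch entry read the same element; with `[1, 1, S, T]`, every batch entry and head does.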
2024-08-16 22:40:04 +00:00
kv_sequence_length=sequence_length if input_format == "Q,K',V'" else None,
2024-06-12 20:04:25 +00:00
max_cache_sequence_length=max_seq_len,
2024-06-19 17:23:26 +00:00
provider="CUDAExecutionProvider",
enable_cuda_graph=False,
2024-06-12 20:04:25 +00:00
device=device,
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the `sdpa_kernel` CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ORT uses BSNH, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
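To make the layout difference concrete, here is a minimal pure-Python sketch that reorders a flat BSNH buffer into BNSH. The function name `bsnh_to_bnsh` is hypothetical and the buffer is assumed to be row-major; the benchmark itself works on tensors, not Python lists.

```python
def bsnh_to_bnsh(flat, B, S, N, H):
    """Reorder a flat row-major [B, S, N, H] buffer into [B, N, S, H]."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                src = ((b * S + s) * N + n) * H  # start of head vector in BSNH
                dst = ((b * N + n) * S + s) * H  # start of head vector in BNSH
                out[dst:dst + H] = flat[src:src + H]
    return out


# B=1, S=2, N=2, H=1: elements are stored as (s0,n0), (s0,n1), (s1,n0), (s1,n1).
converted = bsnh_to_bnsh([0, 1, 2, 3], B=1, S=2, N=2, H=1)
```

The contiguous unit in both layouts is the head vector of length H; only the order of the sequence and head axes changes, which is why the two frameworks can show different memory access patterns for the same logical input.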
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
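The tflops column can be related to the latency column with a short calculation. The sketch below assumes the common estimate of 4·B·S²·N·H floating-point operations for a non-causal forward pass (two matrix multiplications, QKᵀ and PV, each costing 2·B·N·S²·H); the function names are hypothetical, and halving the count for causal attention is a convention assumed here rather than something stated in the source. The figures in the table appear consistent with this estimate, allowing for the latency being rounded to four decimals.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size, causal=False):
    """Estimated forward FLOPs of scaled dot-product attention (assumed formula)."""
    flops = 4 * batch_size * sequence_length * sequence_length * num_heads * head_size
    # Assumption: a causal mask skips roughly half of the work.
    return flops // 2 if causal else flops


def tflops(flops, latency_seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPs per second."""
    return flops / latency_seconds / 1e12


# First table row: batch 4, sequence 2048, 32 heads, head size 128, 0.0015 s.
estimate = tflops(attention_flops(4, 2048, 32, 128), 0.0015)
```

For that row the estimate is about 183 TFLOPs/s against the reported 179.5; the gap is within what rounding 0.00153 s down to 0.0015 s would produce.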
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare ORT latency with PyTorch SDPA on CPU and CUDA.
2024-07-27 01:45:14 +00:00
dtype=torch.float16,
2024-06-12 20:04:25 +00:00
use_kv_cache=False,
input_format=InputFormats.convert(input_format),
)

2024-06-19 17:23:26 +00:00
obj = OrtMultiHeadAttention(config)
2024-06-12 20:04:25 +00:00
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
return ms

benchmark.run(save_path=".", print_data=True)
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
2024-07-27 01:45:14 +00:00
def run_bert_performance_test():
2024-06-12 20:04:25 +00:00
"""
Run performance tests for prompt and token generation.

"""
configures = [
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
2024-07-27 01:45:14 +00:00
# (1, 32, 128, 8192, "TNLGv4"),
# (4, 32, 128, 8192, "TNLGv4"),
2024-06-12 20:04:25 +00:00
(1, 12, 64, 1024, "BertBase"),
(16, 12, 64, 1024, "BertBase"),
(1, 16, 64, 1024, "BertLarge"),
(8, 16, 64, 1024, "BertLarge"),
]

for batch_size, num_heads, head_size, max_seq_len, model_name in configures:
plot_prompt_performance(
batch_size=batch_size,
num_heads=num_heads,
head_size=head_size,
max_seq_len=max_seq_len,
model_name=model_name,
)
Flash Attention v2 MHA (#17227)
2023-08-31 20:52:21 +00:00
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For the Q,K,V format, torch uses BNSH layout while ort uses BSNH layout, so
the comparison is not apples-to-apples. However, a large latency
difference may still be a warning sign.
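To make the layout difference concrete, here is a minimal pure-Python sketch of how the same logical element lands at different flat offsets in BSNH and BNSH layouts. The helper names (`bsnh_offset`, `bnsh_offset`, `bsnh_to_bnsh`) are hypothetical and are not part of benchmark_mha.py; the sketch assumes contiguous row-major tensors.

```python
def bsnh_offset(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous B x S x N x H tensor."""
    return ((b * S + s) * N + n) * H + h


def bnsh_offset(b, s, n, h, S, N, H):
    """Flat offset of the same logical element in a contiguous B x N x S x H tensor."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Reorder a flat BSNH buffer into BNSH order (hypothetical helper for illustration)."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = bsnh_offset(b, s, n, h, S, N, H)
                    dst = bnsh_offset(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


# B=1, S=2, N=2, H=1: BSNH order is (s0,n0), (s0,n1), (s1,n0), (s1,n1).
converted = bsnh_to_bnsh([0, 1, 2, 3], B=1, S=2, N=2, H=1)
print(converted)
```

Because the two layouts touch memory in a different order, any transposition cost or difference in access pattern is included in the measured latency, which is one reason the comparison is only approximate.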
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
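
The tflops column above appears consistent with the common attention FLOP count of 4·B·N·S²·H (two matrix multiplications, QKᵀ and PV, each costing 2·B·N·S²·H) divided by latency. The sketch below assumes that convention; it is not taken from benchmark_mha.py, and the function name `attention_tflops` is hypothetical.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_seconds):
    """Estimate TFLOPs per second for non-causal attention.

    Assumption: FLOPs = 4 * B * N * S * S * H, i.e. two matmuls
    (Q @ K^T and P @ V) at 2 * B * N * S * S * H each. Softmax cost is ignored.
    """
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_seconds / 1e12


# First row of the table: B=4, S=2048, N=32, H=128, latency 0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f}")
```

For the first row this gives roughly 183, near the reported 179.5; the gap is plausibly due to the latency being rounded to four decimal places in the table.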
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
def _parse_arguments():
    parser = argparse.ArgumentParser(description="Benchmark MultiHeadAttention for ONNX Runtime and PyTorch.")

    parser.add_argument(
        "--use_gpu",
        required=False,
        action="store_true",
        help="Use GPU for inference.",
    )
    parser.set_defaults(use_gpu=False)

    parser.add_argument(
        "--use_cuda_graph",
        required=False,
        action="store_true",
        help="Use cuda graph in onnxruntime.",
    )
    parser.set_defaults(use_cuda_graph=False)

    parser.add_argument(
        "--intra_op_num_threads",
        required=False,
        type=int,
        choices=[0, 1, 2, 4, 8, 16],
        default=0,
        help="intra_op_num_threads for onnxruntime. ",
    )

    parser.add_argument(
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends the support to allow
[1, N, S, T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
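
The four supported attention bias shapes can be summarized as: the first dimension is 1 or B, the second is 1 or N, and the last two are exactly S and T. A minimal sketch of such a check follows; the function name `is_supported_attention_bias_shape` is hypothetical and does not reflect the actual operator validation code.

```python
def is_supported_attention_bias_shape(shape, batch_size, num_heads, sequence_length, total_sequence_length):
    """Return True if shape is one of [1|B, 1|N, S, T].

    Hypothetical helper for illustration: broadcasting is allowed only on the
    first two dimensions (batch and heads).
    """
    if len(shape) != 4:
        return False
    b, n, s, t = shape
    return (
        b in (1, batch_size)
        and n in (1, num_heads)
        and s == sequence_length
        and t == total_sequence_length
    )


# Example with B=2, N=8, S=16, T=32.
print(is_supported_attention_bias_shape((1, 1, 16, 32), 2, 8, 16, 32))
print(is_supported_attention_bias_shape((2, 8, 1, 32), 2, 8, 16, 32))
```

The second call shows a shape that is rejected under this reading, because the sequence dimension S is not broadcast.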
2024-08-16 22:40:04 +00:00
        "--causal",
        required=False,
        action="store_true",
        help="test unidirectional",
    )
    parser.set_defaults(causal=False)
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
senarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the result is not apple-to-apple. However, if the latency difference is
large, that could be a warning.
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency
(s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) on Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
parser.add_argument(
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported [1, N, S, T]. This change extends support to [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
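The broadcast rule described in this commit message can be sketched in plain Python. This is an illustration only, not the operator's implementation: the helper names `is_supported_bias_shape` and `bias_index` are hypothetical, and the sketch assumes that only the first two dimensions may be broadcast while S and T must match exactly.

```python
def is_supported_bias_shape(bias_shape, B, N, S, T):
    """Return True if bias_shape is one of [B|1, N|1, S, T] (assumed rule)."""
    if len(bias_shape) != 4:
        return False
    b, n, s, t = bias_shape
    return b in (1, B) and n in (1, N) and s == S and t == T


def bias_index(bias_shape, b, n, s, t):
    """Map an attention-score index (b, n, s, t) to the bias element it reads.

    A dimension of size 1 is broadcast, so its index collapses to 0.
    """
    bi = 0 if bias_shape[0] == 1 else b
    ni = 0 if bias_shape[1] == 1 else n
    return (bi, ni, s, t)


if __name__ == "__main__":
    B, N, S, T = 2, 4, 8, 8
    for shape in [(1, N, S, T), (B, N, S, T), (B, 1, S, T), (1, 1, S, T)]:
        print(shape, is_supported_bias_shape(shape, B, N, S, T))
    # Broadcasting the head dimension: every head reads bias[b, 0, s, t].
    print(bias_index((B, 1, S, T), 1, 3, 5, 6))
```

With a `[B, 1, S, T]` bias, all heads of a batch entry share the same bias slice, which is the usual case for a mask expressed as a bias.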
2024-08-16 22:40:04 +00:00
    "-b",
    "--batch_size",
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses BSNH, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
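To make the layout difference concrete, here is a small pure-Python sketch of how the same element lands at different flat offsets in BSNH and BNSH buffers. The function names are hypothetical and the sketch assumes contiguous row-major storage; it is not taken from the benchmark script.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a contiguous B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, s, n, h, S, N, H):
    """Flat offset of the same element in a contiguous B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Copy a flat BSNH buffer into a new flat BNSH buffer (a transpose of S and N)."""
    out = [None] * (B * S * N * H)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, s, n, h, S, N, H)
                    out[dst] = flat[src]
    return out


if __name__ == "__main__":
    # One batch, two tokens, three heads, head size two.
    print(bsnh_to_bnsh(list(range(12)), 1, 2, 3, 2))
```

In BSNH the heads of one token are adjacent in memory, while in BNSH the tokens of one head are adjacent, so the two kernels see different memory access patterns even for identical inputs.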
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) on Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
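The tflops column can be roughly reproduced from the latency column. The sketch below assumes the common FLOP count for non-causal attention, 4 · B · S² · N · H (two matrix multiplications, QKᵀ and PV, at 2 FLOPs per multiply-add); the exact formula used by the benchmark script is not shown here, and `attention_tflops` is a hypothetical helper name.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPS for non-causal attention, assuming 4*B*S^2*N*H FLOPs."""
    flops = 4 * batch_size * sequence_length ** 2 * num_heads * head_size
    return flops / latency_s / 1e12


if __name__ == "__main__":
    # First table row: Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
    print(round(attention_tflops(4, 2048, 32, 128, 0.0015), 1))
```

For the first row this gives about 183 TFLOPS against the reported 179.5. A gap of that size is expected because the latency column is rounded to four decimal places.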
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) on Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
    required=False,
    type=int,
    default=0,
    help="batch size",
)

parser.add_argument(
    "-s",
    "--sequence_length",
    required=False,
    type=int,
    default=512,
    help="sequence length",
)

parser.add_argument(
    "-p",
    "--past_sequence_length",
    required=False,
    type=int,
    default=0,
    help="past sequence length",
)

parser.add_argument(
    "-n",
    "--num_heads",
    required=False,
    type=int,
    default=16,
    help="number of attention heads",
)

parser.add_argument(
    "-d",
    "--head_size",
    required=False,
    type=int,
    default=64,
    help="hidden dimension per head",
)

parser.add_argument(
    "-r",
    "--repeats",
    required=False,
    type=int,
2024-08-22 00:30:16 +00:00
    default=0,
    help="number of repeats for performance test",
)

parser.add_argument(
    "--torch",
    required=False,
    action="store_true",
    help="test pytorch instead of onnxruntime",
)
parser.set_defaults(torch=False)

parser.add_argument(
    "--has_attn_bias",
    required=False,
    action="store_true",
    help="has attention bias",
)
parser.set_defaults(has_attn_bias=False)

parser.add_argument(
    "--broadcast_attn_bias_dim_0",
    required=False,
    action="store_true",
    help="broadcast attention bias dimension 0",
)
parser.set_defaults(broadcast_attn_bias_dim_0=False)

parser.add_argument(
    "--broadcast_attn_bias_dim_1",
    required=False,
    action="store_true",
    help="broadcast attention bias dimension 1",
)
parser.set_defaults(broadcast_attn_bias_dim_1=False)
2024-07-27 01:45:14 +00:00
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
Flash Attention v2 MHA (#17227)
### Description
Integrate Flash Attention V2 to PackedMultiHeadAttention,
MultiHeadAttention and Attention operators.
Flash Attention v2 source code is from
https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src.
We made some changes to remove the dependency on Torch, then removed the
backward and bfloat16-related code.
Add benchmark script (see benchmark_mha.sh) to compare different
attention kernels for MultiHeadAttention operator.
Current limitations for Flash Attention in PackedMultiHeadAttention,
MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias
shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed
later.
Currently, Flash Attention v2 only works on Linux. For Windows, we will
enable it later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default
value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is
"513", which means that we only enable flash attention when sequence
length is larger than 512 for packed QKV format. Set it to "0" if you
want to use flash attention v2 whenever possible.
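A minimal Python sketch of how these two variables might be set from a test script. The variable names come from the description above; `configure_flash_attention` and `flash_allowed_for_packed_qkv` are hypothetical helpers, and the threshold check only models the documented behavior, not the actual kernel dispatch code. It is probably safest to set the variables before the inference session is created.

```python
import os

# Variable names are taken from the description above.
DISABLE_FLASH = "ORT_DISABLE_FLASH_ATTENTION"
MIN_SEQ_LEN_PACKED_QKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"


def configure_flash_attention(disable=False, min_seq_len_packed_qkv=513):
    """Hypothetical helper: set both variables for the current process."""
    settings = {
        DISABLE_FLASH: "1" if disable else "0",
        MIN_SEQ_LEN_PACKED_QKV: str(min_seq_len_packed_qkv),
    }
    os.environ.update(settings)
    return settings


def flash_allowed_for_packed_qkv(sequence_length):
    """Illustrative model of the documented threshold (an assumption,
    not the operator's real selection logic)."""
    if os.environ.get(DISABLE_FLASH, "0") == "1":
        return False
    min_len = int(os.environ.get(MIN_SEQ_LEN_PACKED_QKV, "513"))
    return sequence_length >= min_len
```

With the defaults, a packed QKV input of length 512 would not use flash attention under this model, while length 513 would; setting the minimum to 0 allows it for any length.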
### Speedup
The following result is from Standard_ND96amsr_A100_v4 VM
(A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per
second for MultiHeadAttention operator.
There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH
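The three formats can be summarized by the tensor shapes they imply. The sketch below is illustrative only: `input_shapes` is a hypothetical helper, and it assumes that in the `Q,KV` format the query stays 3D (BxSxNH) while only the key carries the packed 5D tensor.

```python
def input_shapes(fmt, batch, seq_len, num_heads, head_size):
    """Return the expected input shapes for each format (hypothetical helper)."""
    hidden = num_heads * head_size
    if fmt == "Q,K,V":
        # Three separate 3D inputs of shape B x S x (N*H).
        shape = (batch, seq_len, hidden)
        return {"query": shape, "key": shape, "value": shape}
    if fmt == "Q,KV":
        # Key and value are packed into one 5D tensor: B x S x N x 2 x H.
        return {
            "query": (batch, seq_len, hidden),
            "key": (batch, seq_len, num_heads, 2, head_size),
        }
    if fmt == "QKV":
        # Query, key and value are packed into one 5D tensor: B x S x N x 3 x H.
        return {"query": (batch, seq_len, num_heads, 3, head_size)}
    raise ValueError(f"unknown format: {fmt}")
```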
Note that flash attention cannot use the packed QKV format directly, so
an extra Transpose is needed. We found that the TensorRT kernel is faster
for sequence length <= 512 with packed QKV, possibly because the TensorRT
kernel needs no transpose in this format.
We also noticed that the TensorRT kernel is faster for a Stable Diffusion
512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while
flash attention v2 is faster for a 1024x1024 image (see seq_len=16384,
heads=8, head_dim=40 below).
input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
### Known Issues
NVCC uses a large amount of memory while compiling the flash attention
CUDA kernels. A Linux build with CUDA might fail when the machine has
limited memory and a large number of CPUs. The workaround is to use a
build machine with more memory, or to pass an argument like
`--nvcc_threads 1` to limit the number of nvcc threads in the build.
### Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
|
|
|
if __name__ == "__main__":
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different
scenarios.
* Update benchmark_mha.sh to add cpu benchmarks
For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so
the comparison is not apples-to-apples. However, a large latency
difference could be a warning sign.
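To make the layout difference concrete, here is a small pure-Python sketch that reorders a flat buffer from BSNH to BNSH. The helper names are hypothetical and the code only illustrates the index arithmetic; a real benchmark would use a tensor transpose instead.

```python
def offset_bsnh(b, s, n, h, S, N, H):
    """Flat offset of element (b, s, n, h) in a B x S x N x H buffer."""
    return ((b * S + s) * N + n) * H + h


def offset_bnsh(b, n, s, h, S, N, H):
    """Flat offset of element (b, n, s, h) in a B x N x S x H buffer."""
    return ((b * N + n) * S + s) * H + h


def bsnh_to_bnsh(flat, B, S, N, H):
    """Copy a flat BSNH buffer into BNSH order (illustrative, not optimized)."""
    out = [None] * len(flat)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                for h in range(H):
                    src = offset_bsnh(b, s, n, h, S, N, H)
                    dst = offset_bnsh(b, n, s, h, S, N, H)
                    out[dst] = flat[src]
    return out
```

The data is the same in both layouts; only the memory order of the sequence and head axes differs, which is why the two frameworks may see different memory access patterns for the same problem size.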
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
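The tflops column above appears consistent with the usual non-causal attention FLOP count of 4 * B * S^2 * N * H (two matrix multiplications, each counted as 2 FLOPs per multiply-add). The sketch below is written under that assumption; `attention_tflops` is a hypothetical helper, not necessarily the formula used by the benchmark script.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPs/s for non-causal attention with no past state.

    Assumes 4 * B * S^2 * N * H floating point operations per run
    (QK^T and the softmax(QK^T)V product, 2 FLOPs per multiply-add).
    """
    flops = 4 * batch_size * sequence_length**2 * num_heads * head_size
    return flops / latency_s / 1e12
```

For example, `attention_tflops(4, 2048, 32, 128, 0.0015)` gives about 183, close to the 179.5 reported for `ort:flash`; the gap is plausibly explained by the latency being rounded to four decimal places in the table.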
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: build from source with CUDA
12.5; PyTorch 2.3.1 for cuda 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
|
|
|
args = _parse_arguments()
|
|
|
|
|
print(f"arguments:{args}")
|
|
|
|
|
|
2024-08-22 00:30:16 +00:00
|
|
|
if args.repeats == 0:
|
|
|
|
|
args.repeats = 10000 if args.use_gpu else 100
|
|
|
|
|
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
2024-07-27 01:45:14 +00:00
|
|
|
if args.use_gpu:
|
|
|
|
|
assert torch.cuda.is_available()
|
|
|
|
|
if not args.torch:
|
|
|
|
|
assert "CUDAExecutionProvider" in get_available_providers()
|
2024-06-19 17:23:26 +00:00
|
|
|
|
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
2024-07-27 01:45:14 +00:00
if args.torch:
assert Version(torch.__version__) >= Version("2.3.0")
Extend Attention Bias Broadcast Support (#21710)
### Description
Previously, MultiHeadAttention supported relative position bias of shape
[1, N, S, T] or [B, N, S, T], and DecoderMaskedMultiHeadAttention
supported only [1, N, S, T]. This change extends support to [1, N, S,
T], [B, N, S, T], [B, 1, S, T] and [1, 1, S, T] for the CUDA and CPU EPs.
- [x] Rename the input of "relative position bias" to "attention bias"
because it can also be used for other types of bias, like ALiBi
(Attention with Linear Biases) or attention mask.
- [x] Update unfused kernel to support broadcasting 2nd dimension of
attention bias.
- [x] Update efficient attention to support broadcasting 2nd dimension
of attention bias.
- [x] Update operators (MultiHeadAttention,
DecoderMaskedMultiHeadAttention, Attention, PackedAttention,
PackedMultiHeadAttention) to support broadcast attention bias on CUDA
and CPU EPs.
- [x] Update ROCm, DML and WebGPU naming to be consistent. (Note that
those EPs do not support broadcasting attention_bias for now).
- [x] Add attention bias tests for MultiHeadAttention.
- [x] Update operator documents
- [x] Update benchmark script
Other changes:
* Fix some checks in multihead-attention.ts
* Add helper functions to dump tensors given dimensions.
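To illustrate the broadcast shapes listed above, here is a minimal pure-Python sketch of adding a bias of shape [1 or B, 1 or N, S, T] to attention scores of shape [B, N, S, T]. The helper names `bias_offset` and `add_bias` are hypothetical and are not taken from the ONNX Runtime kernels; the sketch only shows the indexing idea, where a dimension of size 1 is always read at index 0.

```python
def bias_offset(bias_shape, b, n, s, t):
    """Row-major offset into a bias of shape (1 or B, 1 or N, S, T).

    A dimension of size 1 is broadcast: it is always read at index 0.
    (Hypothetical helper, for illustration only.)
    """
    bias_b, bias_n, S, T = bias_shape
    bb = b if bias_b > 1 else 0
    nn = n if bias_n > 1 else 0
    return ((bb * bias_n + nn) * S + s) * T + t


def add_bias(scores, bias, bias_shape, B, N):
    """Return scores (flat, shape B x N x S x T) with the broadcast bias added."""
    S, T = bias_shape[2], bias_shape[3]
    out = list(scores)
    for b in range(B):
        for n in range(N):
            for s in range(S):
                for t in range(T):
                    i = ((b * N + n) * S + s) * T + t
                    out[i] += bias[bias_offset(bias_shape, b, n, s, t)]
    return out


B, N, S, T = 2, 3, 2, 2
scores = [0.0] * (B * N * S * T)

# Bias of shape [1, 1, S, T]: shared by every batch and every head.
out_11 = add_bias(scores, [1.0, 2.0, 3.0, 4.0], (1, 1, S, T), B, N)

# Bias of shape [B, 1, S, T]: one bias per batch, shared by all heads.
bias_b1 = [0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]
out_b1 = add_bias(scores, bias_b1, (B, 1, S, T), B, N)
```

With zero scores, `out_11` repeats the same S x T block for all six (batch, head) pairs, while `out_b1` repeats one block per batch across its three heads.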
2024-08-16 22:40:04 +00:00
assert args.past_sequence_length == 0
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables
for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different
scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
For the Q,K,V format, torch uses the BNSH layout while ort uses BSNH, so
the comparison is not apples-to-apples. However, a large latency
difference could still be a warning sign.
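
As a small illustration of the two layouts mentioned above, the following pure-Python sketch reorders a flat, row-major buffer from BSNH (batch, sequence, heads, head size) to BNSH (batch, heads, sequence, head size). The function name `bsnh_to_bnsh` is hypothetical and is not part of benchmark_mha.py; it only shows that the same values sit at different offsets in the two layouts.

```python
def bsnh_to_bnsh(data, B, S, N, H):
    """Reorder a flat row-major buffer from B x S x N x H to B x N x S x H.

    Hypothetical helper for illustration; real code would use a tensor
    transpose instead of copying element blocks by hand.
    """
    out = [None] * len(data)
    for b in range(B):
        for s in range(S):
            for n in range(N):
                src = ((b * S + s) * N + n) * H  # offset in BSNH
                dst = ((b * N + n) * S + s) * H  # offset in BNSH
                out[dst:dst + H] = data[src:src + H]
    return out


B, S, N, H = 1, 2, 2, 2
bsnh = list(range(B * S * N * H))
bnsh = bsnh_to_bnsh(bsnh, B, S, N, H)
print(bnsh)  # [0, 1, 4, 5, 2, 3, 6, 7]
```

In BSNH the two heads of one token are adjacent, whereas in BNSH all tokens of one head are adjacent, so a kernel reads memory in a different pattern for each layout.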
#### Example GPU results
Example results on A100-SXM4-80GB with settings (use_gpu=TRUE,
enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0,
intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient
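
The tflops column above can be related to the latency column with a short sketch. It assumes the common FLOP count for non-causal attention, 4 x B x N x S x S x H (two matrix multiplications, QK^T and PV, at 2 x B x N x S x S x H each); the exact formula used by the benchmark script is not shown here, and the function name `attention_tflops` is hypothetical.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Estimate TFLOPS for non-causal self-attention.

    Assumption: FLOPs = 4 * B * N * S * S * H, counting one multiply-add
    as two operations for each of the two matrix multiplications.
    """
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    return flops / latency_s / 1e12


# First row of the table: B=4, S=2048, N=32, H=128, latency 0.0015 s.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f} TFLOPS")
```

This gives about 183 TFLOPS for the first row, close to the reported 179.5. The gap is consistent with the latency being rounded to four decimal places in the table.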
#### Example CPU results
Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE,
past_sequence_length=0) in Windows. ORT: built from source with CUDA
12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default
### Motivation and Context
To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
if args.use_gpu and args.batch_size == 0 and not args.torch:
2024-06-19 17:23:26 +00:00
if platform.system() == "Linux":
s = torch.cuda.Stream()
with torch.cuda.stream(s), torch.no_grad():
run_bert_performance_test()
2024-06-12 20:04:25 +00:00
run_tflops_tests(args)