onnxruntime/onnxruntime/test/python/transformers/benchmark_mha.py

1374 lines
49 KiB
Python

Flash Attention v2 MHA (#17227)

### Description

Integrate Flash Attention V2 into the PackedMultiHeadAttention, MultiHeadAttention and Attention operators.

Flash Attention v2 source code is from https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src. We made some changes to remove the dependency on Torch, then removed backward and bfloat16 related code.

Add a benchmark script (see benchmark_mha.sh) to compare different attention kernels for the MultiHeadAttention operator.

Current limitations for Flash Attention in PackedMultiHeadAttention, MultiHeadAttention and Attention operators:
* Relative Position Bias is not supported
* Different hidden size for Q and V is not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is past or present input, bias shall be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported

Some limitations (like attention mask and causal) might be removed later.

Currently, Flash Attention v2 only works in Linux. For Windows, we will enable it later with Cutlass 3.2.

Two environment variables can be used for testing purposes:
(1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default value is 0 (enable). Set it to "1" to disable it.
(2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is "513", which means that we only enable flash attention when sequence length is larger than 512 for packed QKV format. Set it to "0" if you want to use flash attention v2 whenever possible.

### Speedup

The following result is from Standard_ND96amsr_A100_v4 VM (A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per second for the MultiHeadAttention operator.

There are 3 input formats:
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH

Note that flash attention cannot use the packed QKV format, so an extra Transpose is needed. We found that the TensorRT kernel is faster for sequence length <= 512 for packed QKV. The reason might be that no transpose is needed for the TensorRT kernel in this format.

We also notice that the TensorRT kernel is faster for a stable diffusion 512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while flash attention v2 is faster for a 1024x1024 image (see seq_len=16384, heads=8, head_dim=40 below).

input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95

### Known Issues

NVCC uses a huge amount of memory while compiling the flash attention CUDA kernel. A Linux build with CUDA might fail when the machine has limited memory while the number of CPUs is large. The workaround is to use a build machine with more memory, or use an argument like `--nvcc_threads 1` to limit nvcc threads in the build.

### Motivation and Context

Increases speed and efficiency of MHA or Packed MHA.

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
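The three input formats named in the commit message above can be illustrated with plain arrays. The following is a minimal sketch, assuming NumPy and hypothetical helper names (`make_qkv`, `to_separate`, `to_packed_kv`, `to_packed_qkv`) that are not taken from benchmark_mha.py; it only shows the shapes BxSxNH, BxSxNx2xH and BxSxNx3xH, with B = batch size, S = sequence length, N = heads, H = head dim.

```python
import numpy as np


def make_qkv(batch, seq_len, num_heads, head_size, seed=0):
    # Hypothetical helper: random float16 query/key/value in B x S x N x H layout.
    rng = np.random.default_rng(seed)
    shape = (batch, seq_len, num_heads, head_size)
    return tuple(rng.standard_normal(shape).astype(np.float16) for _ in range(3))


def to_separate(q, k, v):
    # "Q,K,V": three separate 3D tensors of shape B x S x (N*H).
    b, s, n, h = q.shape
    return tuple(x.reshape(b, s, n * h) for x in (q, k, v))


def to_packed_kv(q, k, v):
    # "Q,KV": query stays B x S x (N*H); key and value are packed as B x S x N x 2 x H.
    b, s, n, h = q.shape
    return q.reshape(b, s, n * h), np.stack([k, v], axis=3)


def to_packed_qkv(q, k, v):
    # "QKV": query, key and value are packed into one 5D tensor B x S x N x 3 x H.
    return np.stack([q, k, v], axis=3)


q, k, v = make_qkv(batch=2, seq_len=8, num_heads=4, head_size=16)
print(to_separate(q, k, v)[0].shape)   # separate query
print(to_packed_kv(q, k, v)[1].shape)  # packed key/value
print(to_packed_qkv(q, k, v).shape)    # packed query/key/value
```

Slicing index 0, 1 or 2 along the fourth axis of the packed QKV tensor recovers the original query, key or value, which is one way to see why a kernel that expects separate tensors needs an extra copy or transpose for this format.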
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)

### Description

* Update benchmark_mha.py to compare with PyTorch SDPA api.
* Write results to csv file.
* Use sdpa_kernel cuda provider option instead of environment variables for better control.
* Add arguments (`--use_gpu`, `--causal` etc) to allow testing different scenarios.
* Update benchmark_mha.sh to add cpu benchmarks

For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the comparison is not apples-to-apples. However, if the latency difference is large, that could be a warning.

#### Example GPU results

Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1.

format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient

#### Example CPU results

Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1.

format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default

### Motivation and Context

To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
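The GPU table above reports both latency and tflops. The commit message does not spell out how the two are related, so the following is only a sketch under an assumption: that attention cost is counted as two matrix multiplications (scores and weighted sum), giving 4 * batch_size * sequence_length^2 * num_heads * head_size floating-point operations, halved for causal attention. The function names `attention_flops` and `tflops_per_second` are illustrative, not confirmed to match the script.

```python
import math


def attention_flops(batch_size, sequence_length, num_heads, head_size, causal=False):
    # Assumed convention: Q @ K^T and P @ V each cost 2 * S * S * H
    # floating-point operations per head and per batch item.
    total = 4 * batch_size * sequence_length**2 * num_heads * head_size
    # With a causal mask roughly half of the score matrix is skipped.
    return total // 2 if causal else total


def tflops_per_second(flops, latency_seconds):
    # Guard against failed runs that report a NaN or non-positive latency.
    if math.isnan(latency_seconds) or latency_seconds <= 0:
        return 0.0
    return flops / latency_seconds / 1e12


# Settings of the first GPU row: batch_size=4, sequence_length=2048, num_heads=32, head_size=128.
flops = attention_flops(4, 2048, 32, 128)
print(f"{tflops_per_second(flops, 0.0015):.1f} TFLOPs/s at the rounded latency of 0.0015 s")
```

Under this assumption the first row works out to roughly 183 TFLOPs/s at the rounded latency of 0.0015 s, in the same range as the reported 179.5; the gap is consistent with the latency column being rounded to four decimals.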
Benchmark performance of MultiHeadAttention with ORT or PyTorch.
On Linux, run the following:
    sh benchmark_mha.sh
On Windows, run the following:
    benchmark_mha.cmd
"""
import argparse
import csv
import math
import os
import platform
import re
import statistics
import sys
import threading
import time
from contextlib import nullcontext
from datetime import datetime
from enum import IntEnum
from typing import Callable, Dict, List, Optional, Tuple
import torch
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)

### Description

* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.

For the Q,K,V format, torch uses the BNSH layout while ort uses the BSNH layout, so the comparison is not apples-to-apples. However, a large latency difference may still be a warning sign.

#### Example GPU results

Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1.

format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient

#### Example CPU results

Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1.

format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default

### Motivation and Context

To compare with PyTorch SDPA on CPU and CUDA latency.
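
The `tflops` column in the GPU results is consistent with the usual FLOP count for non-causal attention: two matrix multiplications (`Q @ K^T` and `P @ V`), each costing `2 * B * N * S * S * H` operations. The following is a minimal sketch of that arithmetic under this assumption; the function names `attention_flops` and `tflops_per_second` are hypothetical and are not taken from benchmark_mha.py.

```python
def attention_flops(batch_size, sequence_length, num_heads, head_size):
    """Approximate FLOPs of one non-causal attention call (an assumed formula).

    Counts two matmuls (Q @ K^T and P @ V), each 2 * B * N * S * S * H,
    and ignores softmax, scaling and any layout conversion.
    """
    return 4 * batch_size * num_heads * sequence_length * sequence_length * head_size


def tflops_per_second(flops, latency_seconds):
    """Convert a FLOP count and a measured latency into TFLOPs per second."""
    return flops / latency_seconds / 1e12


# Shape of the first GPU rows: batch_size=4, sequence_length=2048,
# num_heads=32, head_size=128.
flops = attention_flops(4, 2048, 32, 128)
print(flops)  # 274877906944

# The table rounds latency to 4 decimals, so this only approximates the
# reported 179.5 TFLOPs/s for the 0.0015 s row.
print(round(tflops_per_second(flops, 0.0015), 1))  # 183.3
```

Because the printed latency is rounded, recomputed values differ from the table by a few percent; the unrounded latency would be needed to reproduce the column exactly.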
2024-07-27 01:45:14 +00:00
import torch.utils.benchmark as benchmark
from onnx import TensorProto, helper
from packaging.version import Version
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnxruntime.transformers.io_binding_helper import CudaSession
class InputFormats:
    # Integer IDs for the supported layouts of query, key and value.
    Q_K_V_BSNH_BSNH_BSNH = 0
    QKV_BSN3H = 1
    Q_KV_BSNH_BSN2H = 2
    Q_K_V_BSNH_BNSH_BNSH = 3  # For cross attention

    @staticmethod
    def input_format_str(format: int) -> str:
        """Return the display name for an input format ID."""
        names = InputFormats.get_name_list()
        return names[format]

    @staticmethod
    def convert(format_str: str) -> int:
        """Return the input format ID for a display name."""
        names = InputFormats.get_name_list()
        return names.index(format_str)

    @staticmethod
    def get_name_list() -> List[str]:
        # The order must match the integer constants defined above.
        return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"]
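The mapping between format IDs, display names and tensor layouts can be sketched as follows. This is a minimal illustration, not part of the benchmark script: the class is repeated only so the snippet runs on its own, `example_input_shapes` is a hypothetical helper, and the shapes for `Q,K',V'` are an assumption read off the constant name `Q_K_V_BSNH_BNSH_BNSH`. The other shapes follow the commit description (B = batch size, S = sequence length, N = number of heads, H = head size).

```python
from typing import Dict, List, Tuple


class InputFormats:
    # Copy of the class above, repeated so this sketch is self-contained.
    Q_K_V_BSNH_BSNH_BSNH = 0
    QKV_BSN3H = 1
    Q_KV_BSNH_BSN2H = 2
    Q_K_V_BSNH_BNSH_BNSH = 3  # For cross attention

    @staticmethod
    def get_name_list() -> List[str]:
        return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"]

    @staticmethod
    def input_format_str(format: int) -> str:
        return InputFormats.get_name_list()[format]

    @staticmethod
    def convert(format_str: str) -> int:
        return InputFormats.get_name_list().index(format_str)


def example_input_shapes(input_format: int, b: int, s: int, n: int, h: int) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: shapes of the inputs implied by each format."""
    if input_format == InputFormats.QKV_BSN3H:
        # Packed QKV: query is 5D, key and value are not passed separately.
        return {"query": (b, s, n, 3, h)}
    if input_format == InputFormats.Q_KV_BSNH_BSN2H:
        # Packed KV: key is 5D and carries both key and value.
        return {"query": (b, s, n * h), "key": (b, s, n, 2, h)}
    if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
        # Assumed from the constant name: key and value are laid out as BNSH.
        return {"query": (b, s, n * h), "key": (b, n, s, h), "value": (b, n, s, h)}
    # Separate query, key and value, each of shape B x S x (N*H).
    return {"query": (b, s, n * h), "key": (b, s, n * h), "value": (b, s, n * h)}


if __name__ == "__main__":
    for name in InputFormats.get_name_list():
        format_id = InputFormats.convert(name)
        print(name, format_id, example_input_shapes(format_id, 2, 8, 4, 16))
```

Because `convert` and `input_format_str` both index into the same list, a name converted to an ID and back is unchanged as long as the list order matches the integer constants.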
2024-07-27 01:45:14 +00:00
class SdpaKernel(IntEnum):
    """Bit flags for sdpa_kernel CUDA provider option"""

    DEFAULT = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    TRT_FUSED_ATTENTION = 4
    CUDNN_FLASH_ATTENTION = 8
    MATH = 16
    TRT_FLASH_ATTENTION = 32
    TRT_CROSS_ATTENTION = 64
    TRT_CAUSAL_ATTENTION = 128
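# A minimal sketch of how bit flags like the ones above can be combined into one value.
# It assumes that the sdpa_kernel provider option accepts the OR-ed integer mask as a string;
# that is not confirmed by this file. SdpaKernelSketch and sdpa_kernel_option are hypothetical
# names used only for illustration, and the enum repeats a subset of the flags so that the
# snippet runs on its own.

```python
from enum import IntEnum


class SdpaKernelSketch(IntEnum):
    """Hypothetical copy of a few SdpaKernel flags, for illustration only."""

    DEFAULT = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    MATH = 16


def sdpa_kernel_option(*kernels: SdpaKernelSketch) -> dict:
    """Hypothetical helper: OR the flags together and wrap the mask in an options dict."""
    mask = 0
    for kernel in kernels:
        mask |= int(kernel)
    # Assumption: the option value is passed as a string holding the integer mask.
    return {"sdpa_kernel": str(mask)}


cuda_options = sdpa_kernel_option(
    SdpaKernelSketch.FLASH_ATTENTION, SdpaKernelSketch.EFFICIENT_ATTENTION
)
print(cuda_options)  # {'sdpa_kernel': '3'}
```

# Each flag is a distinct power of two, so any combination maps to a unique mask, and
# DEFAULT (0) leaves the kernel choice to the runtime.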


# Attention bias is supported, so masks only need to be supported up to 2D.
class AttentionMaskFormat(IntEnum):
    Mask_None = 0  # No attention mask.
    Mask_1D_Key_SeqLen = 1  # Shape (batch_size), actual sequence lengths (excluding padding on the right side).
    Mask_2D_Key_PaddingMask = 2  # Shape (batch_size, total_sequence_length), key padding mask.


class MultiHeadAttentionConfig:
    def __init__(
        self,
        batch_size: int,
        sequence_length: int,
        num_heads: int,
        head_size: int,
        causal: bool,
        past_sequence_length: int = 0,
        kv_sequence_length=None,
        max_cache_sequence_length=None,
        scale: float = 0.0,
        provider="CPUExecutionProvider",
        device: Optional[torch.device] = None,
        enable_cuda_graph: bool = False,
        dtype=torch.float,
        use_kv_cache: bool = False,
        has_past_input: bool = False,
        share_past_present_buffer: bool = False,
        input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH,
        verbose: bool = False,
        has_bias: bool = False,  # bias for input projection
        has_attn_bias: bool = False,  # bias added before softmax, for example relative position bias
        broadcast_attn_bias_dim_0: bool = False,  # broadcast attention bias dimension 0
        broadcast_attn_bias_dim_1: bool = False,  # broadcast attention bias dimension 1
        mask_format: int = AttentionMaskFormat.Mask_None,
    ):
        self.operator = "MultiHeadAttention"
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.kv_sequence_length = kv_sequence_length or sequence_length
        self.max_cache_sequence_length = max_cache_sequence_length
        self.past_sequence_length = past_sequence_length
        self.num_heads = num_heads
        self.head_size = head_size
        self.causal = causal
        self.scale = scale or (1.0 / (head_size**0.5))
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498) ### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe: * Add an once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once; and change the cache to be read only after initialization. Previously, the content is not read-only so it might be changed by other thread and potentially cause buffer overrun. * The kernel initialization is not guarded (Although the factory of kernel loading has static mutex to guard multiple threading), so the mutable variable might be set by two different threads at the same time. Add an once_flag to avoid that. This requires need some workspace computation change as well. So I did not create a separated pull request. (2) **Bias for cross attention** That scenario has assumption that only query has bias, but not for key and value. However, such assumption is not verified in runtime and there was no comment of assumption, and there was no test case so the support of scenario was disabled by mistake. Actually, the scenario is used in whisper model (TODO: we shall add tests for whisper to CI pipeline, and also update fusion script to verify such assumptions if needed.) CUDA/CPU kernels supports bias for cross attention as long as bias is zero for key and value. I updated the check to support the scenario and added comments wherever there is such assumption. (3) **Fallback support** Previously, unfused kernel did not support packed qkv and packed kv formats. That means some case might fail since there is no fallback. I added new AddBiasTranpose cuda kernels for them to support fallback, so that all supported cases will not fail. #### Improvements (4) **QKV workspace size**. 
The logic for no_qkv_workspace could be easily out of sync since related code are scattered in different source files. I refactor the code to move all related code to one file (attention_prepare_qkv.cu) and add asserts, so that the logic can be in sync. (5) **Remove confusing concept of pass past in kv** parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is more clear. (6) **More coverage and less workspace and less transpose of flash and efficient attention** Previously, there is one condition does not run flash or efficient attention: ``` bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; ``` After this change, we can use flash and efficient attention for the case, and also less workspace. For example, cross attention with bias, the original code uses two additional workspaces: ``` transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) add bias: query => q, temp_k_workspace => k, temp_v_workspace => v ``` New logic is like ``` if (has bias) Add bias to query, key, value, and store in q, k, v workspace else Use query, key and value directly as q, k and v in kernel ``` We can see that, we do not need allocate temp_k_workspace and temp_v_workspace so use less memory. New code saved two transposes in this case. Flash and efficient attention supports BSNH or BNSH formats for k and v. In old code, k/v are also converted to BSNH format. Some is not necessary. I do some change to convert k/v to BSNH or BNSH case by case. So that there are more cases can be covered by flash or efficient attention to improve performance. (6) **Debugging support** Previously, there is less debug info. In this change, I add a flag for debug info in the AttentionData. 
So that we can output debug info during the processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disable dumping from cuda kernel. #### Summary of changes (1) Refactoring the CheckInputs, and pass in operator type. (2) Refactoring the PrepareQKV to support fallback for packed qkv or packed kv inputs. (3) Change a few case of PrepareQKV to allow more case covered by flash and efficient attention. (4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv (5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of assumption that key/value has no bias in this case. (6) Fix thread-safe issue in CumulatedSequenceLengthCache handling. (7) Add test cases to cover all supported scenarios. Current support scenarios for MultiHeadAttention for CUDA/CPU: | Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc | ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|--------- | BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed | BLN3H| - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU | BSNH | BLN2H| - | - | - | - | - | --- | kv packed <br> not support in CPU | BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention <br> bias for Q only | BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
# Support the case where there is no past input but a present output is needed (the prompt case).
self.has_past_input = has_past_input
if has_past_input:
    assert use_kv_cache
else:  # no past input
    assert past_sequence_length == 0
self.has_present_output = use_kv_cache
self.use_kv_cache = use_kv_cache
if not use_kv_cache:
    assert past_sequence_length == 0
else:
    assert self.kv_sequence_length == self.sequence_length
# Only BSNH input format supports past state.
if input_format != InputFormats.Q_K_V_BSNH_BSNH_BSNH:
    assert not self.has_past_input
    assert not self.has_present_output
# Derived values
self.total_sequence_length = self.kv_sequence_length + past_sequence_length
self.past_buffer_length = self.max_cache_sequence_length if share_past_present_buffer else past_sequence_length
self.present_buffer_length = (
    self.max_cache_sequence_length if share_past_present_buffer else self.total_sequence_length
)
        self.provider = provider
        self.device = device
        self.enable_cuda_graph = enable_cuda_graph
        self.dtype = dtype
        self.share_past_present_buffer = share_past_present_buffer
        self.input_format = input_format
        self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H
        self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H
Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | 
ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. 
ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel -- | -- | -- | -- | -- | -- | -- | -- | -- Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default Q,K,V | 
FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
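The `tflops` column in the GPU results above can be roughly reproduced from the `latency (s)` column. The sketch below assumes the common non-causal attention cost of 4 * B * S_q * S_kv * N * H floating point operations (two matrix multiplications, each counted as 2 FLOPs per multiply-add). The function name `attention_tflops` and its parameters are hypothetical and only for illustration. The formula is an assumption that happens to agree with the reported rows, not a statement of how the script computes the metric.

```python
from typing import Optional


def attention_tflops(
    batch_size: int,
    sequence_length: int,
    num_heads: int,
    head_size: int,
    latency_seconds: float,
    kv_sequence_length: Optional[int] = None,
) -> float:
    """Estimate TFLOPs per second for one non-causal attention call.

    Assumption: cost = 4 * B * S_q * S_kv * N * H, which counts Q*K^T and
    softmax(QK^T)*V as 2 FLOPs per multiply-add and ignores the softmax itself.
    """
    if kv_sequence_length is None:
        # Self-attention without past state: keys are as long as queries.
        kv_sequence_length = sequence_length
    flops = 4 * batch_size * sequence_length * kv_sequence_length * num_heads * head_size
    return flops / latency_seconds / 1e12


# First GPU row above: batch_size=4, sequence_length=2048, num_heads=32, head_size=128.
estimate = attention_tflops(4, 2048, 32, 128, latency_seconds=0.0015)
print(f"{estimate:.1f} TFLOPs/s")
```

With the latency of 0.0015 s shown in the table, the estimate is about 183 TFLOPs/s against the reported 179.5. A gap of that size is what one would expect from the latency being rounded to four decimal places.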
        self.verbose = verbose
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498) ### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe: * Add an once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once; and change the cache to be read only after initialization. Previously, the content is not read-only so it might be changed by other thread and potentially cause buffer overrun. * The kernel initialization is not guarded (Although the factory of kernel loading has static mutex to guard multiple threading), so the mutable variable might be set by two different threads at the same time. Add an once_flag to avoid that. This requires need some workspace computation change as well. So I did not create a separated pull request. (2) **Bias for cross attention** That scenario has assumption that only query has bias, but not for key and value. However, such assumption is not verified in runtime and there was no comment of assumption, and there was no test case so the support of scenario was disabled by mistake. Actually, the scenario is used in whisper model (TODO: we shall add tests for whisper to CI pipeline, and also update fusion script to verify such assumptions if needed.) CUDA/CPU kernels supports bias for cross attention as long as bias is zero for key and value. I updated the check to support the scenario and added comments wherever there is such assumption. (3) **Fallback support** Previously, unfused kernel did not support packed qkv and packed kv formats. That means some case might fail since there is no fallback. I added new AddBiasTranpose cuda kernels for them to support fallback, so that all supported cases will not fail. #### Improvements (4) **QKV workspace size**. 
The logic for no_qkv_workspace could be easily out of sync since related code are scattered in different source files. I refactor the code to move all related code to one file (attention_prepare_qkv.cu) and add asserts, so that the logic can be in sync. (5) **Remove confusing concept of pass past in kv** parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is more clear. (6) **More coverage and less workspace and less transpose of flash and efficient attention** Previously, there is one condition does not run flash or efficient attention: ``` bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; ``` After this change, we can use flash and efficient attention for the case, and also less workspace. For example, cross attention with bias, the original code uses two additional workspaces: ``` transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) add bias: query => q, temp_k_workspace => k, temp_v_workspace => v ``` New logic is like ``` if (has bias) Add bias to query, key, value, and store in q, k, v workspace else Use query, key and value directly as q, k and v in kernel ``` We can see that, we do not need allocate temp_k_workspace and temp_v_workspace so use less memory. New code saved two transposes in this case. Flash and efficient attention supports BSNH or BNSH formats for k and v. In old code, k/v are also converted to BSNH format. Some is not necessary. I do some change to convert k/v to BSNH or BNSH case by case. So that there are more cases can be covered by flash or efficient attention to improve performance. (6) **Debugging support** Previously, there is less debug info. In this change, I add a flag for debug info in the AttentionData. 
So that we can output debug info during the processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disable dumping from cuda kernel. #### Summary of changes (1) Refactoring the CheckInputs, and pass in operator type. (2) Refactoring the PrepareQKV to support fallback for packed qkv or packed kv inputs. (3) Change a few case of PrepareQKV to allow more case covered by flash and efficient attention. (4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv (5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of assumption that key/value has no bias in this case. (6) Fix thread-safe issue in CumulatedSequenceLengthCache handling. (7) Add test cases to cover all supported scenarios. Current support scenarios for MultiHeadAttention for CUDA/CPU: | Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc | ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|--------- | BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed | BLN3H| - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU | BSNH | BLN2H| - | - | - | - | - | --- | kv packed <br> not support in CPU | BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention <br> bias for Q only | BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
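The layout codes in the supported scenarios table above (BSNH, BLN3H, BLN2H, BNLH, BNPH, BNTH) can be expanded into concrete shapes. The sketch below is illustrative only: the helper `mha_shapes` and its scenario names are hypothetical, and it assumes B = batch size, S = query sequence length, L = key/value sequence length, N = number of heads, H = head size, P = past sequence length, and T = P + L for the present state.

```python
from typing import Dict, Tuple


def mha_shapes(
    scenario: str, B: int, S: int, L: int, N: int, H: int, P: int = 0, with_present: bool = False
) -> Dict[str, Tuple[int, ...]]:
    """Expand the table's layout codes into shapes (hypothetical helper)."""
    if scenario == "not_packed":
        shapes = {"query": (B, S, N, H), "key": (B, L, N, H), "value": (B, L, N, H)}
    elif scenario == "qkv_packed":
        # Q, K and V share one tensor laid out as BLN3H.
        shapes = {"query": (B, L, N, 3, H)}
    elif scenario == "kv_packed":
        # K and V share one tensor laid out as BLN2H.
        shapes = {"query": (B, S, N, H), "key": (B, L, N, 2, H)}
    elif scenario == "cross_attention":
        # K and V arrive already transposed to BNLH.
        shapes = {"query": (B, S, N, H), "key": (B, N, L, H), "value": (B, N, L, H)}
    else:
        raise ValueError(f"unknown scenario: {scenario}")

    if P > 0:
        shapes["past_key"] = (B, N, P, H)
        shapes["past_value"] = (B, N, P, H)
    if with_present or P > 0:
        T = P + L  # assumption: total length is past length plus new key/value length
        shapes["present_key"] = (B, N, T, H)
        shapes["present_value"] = (B, N, T, H)
    return shapes


for name, shape in mha_shapes("not_packed", B=2, S=1, L=1, N=4, H=16, P=7).items():
    print(name, shape)
```

The 4D and 5D tuples follow the table's notation. The benchmark's own `shape_dict` may flatten the N and H dimensions into one hidden dimension, as its `output` entry does.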
        self.has_bias = has_bias
        self.has_attn_bias = has_attn_bias
        self.broadcast_attn_bias_dim_0 = broadcast_attn_bias_dim_0
        self.broadcast_attn_bias_dim_1 = broadcast_attn_bias_dim_1

        assert mask_format in [
            AttentionMaskFormat.Mask_None,
            AttentionMaskFormat.Mask_1D_Key_SeqLen,
            AttentionMaskFormat.Mask_2D_Key_PaddingMask,
        ]
        self.mask_format = mask_format

        # mask_index_q and mask_index_kv will be updated in random_inputs() if mask_format is not Mask_None.
        self.mask_index_kv = torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length
        self.mask_index_q = (
            torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.total_sequence_length
        )
    def __repr__(self):
        return (
            f"MultiHeadAttentionConfig(batch_size={self.batch_size}, sequence_length={self.sequence_length}, "
            f"num_heads={self.num_heads}, head_size={self.head_size}, "
            f"kv_sequence_length={self.kv_sequence_length}, past_sequence_length={self.past_sequence_length}, "
            f"max_cache_sequence_length={self.max_cache_sequence_length}, "
            f"causal={self.causal}, scale={self.scale}, use_kv_cache={self.use_kv_cache}, "
            f"share_past_present_buffer={self.share_past_present_buffer}, "
            f"provider={self.provider}, device={self.device}, enable_cuda_graph={self.enable_cuda_graph}, "
            f"dtype={self.dtype}, input_format={InputFormats.input_format_str(self.input_format)}, "
            f"has_bias={self.has_bias}, mask_format={self.mask_format}, "
            f"has_attn_bias={self.has_attn_bias}, "
            f"broadcast_attn_bias_dim_0={self.broadcast_attn_bias_dim_0}, "
            f"broadcast_attn_bias_dim_1={self.broadcast_attn_bias_dim_1})"
        )
    def shape_dict(self, input_format=None):
        shapes: Dict[str, Tuple] = {
            "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
        }

        input_format = input_format or self.input_format
FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
        if input_format == InputFormats.QKV_BSN3H:
            shapes = {
                **shapes,
                "query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size),
            }
        elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
            shapes = {
                **shapes,
                "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
                "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size),
            }
        elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
            shapes = {
                **shapes,
                "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
                "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
                "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
            }
        else:
            assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH
            shapes = {
                **shapes,
                "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
                "key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size),
                "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size),
            }
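The four input layouts handled by the branch above can be illustrated outside the class. The following is a minimal standalone sketch, not the benchmark's own code: the function name `qkv_input_shapes`, the helper `total_elements`, and the use of plain strings in place of the `InputFormats` members are assumptions made for illustration. It shows that every layout carries the same number of Q/K/V elements and only the packing differs.

```python
from math import prod


def qkv_input_shapes(fmt, batch_size, sequence_length, num_heads, head_size):
    """Return input shapes for one layout (hypothetical helper; string names stand in for InputFormats)."""
    hidden_size = num_heads * head_size
    if fmt == "QKV_BSN3H":
        # Packed QKV: a single 5D tensor holds query, key and value.
        return {"query": (batch_size, sequence_length, num_heads, 3, head_size)}
    if fmt == "Q_KV_BSNH_BSN2H":
        # Separate query, with key and value packed into one 5D tensor.
        return {
            "query": (batch_size, sequence_length, hidden_size),
            "key": (batch_size, sequence_length, num_heads, 2, head_size),
        }
    if fmt == "Q_K_V_BSNH_BSNH_BSNH":
        # Three separate 3D tensors, heads folded into the hidden dimension.
        return {name: (batch_size, sequence_length, hidden_size) for name in ("query", "key", "value")}
    if fmt == "Q_K_V_BSNH_BNSH_BNSH":
        # Query is 3D; key and value are 4D with the head axis before the sequence axis.
        return {
            "query": (batch_size, sequence_length, hidden_size),
            "key": (batch_size, num_heads, sequence_length, head_size),
            "value": (batch_size, num_heads, sequence_length, head_size),
        }
    raise ValueError(f"unknown input format: {fmt}")


def total_elements(shapes):
    """Sum the element counts of all tensors in a shape dictionary."""
    return sum(prod(shape) for shape in shapes.values())


if __name__ == "__main__":
    for fmt in ("QKV_BSN3H", "Q_KV_BSNH_BSN2H", "Q_K_V_BSNH_BSNH_BSNH", "Q_K_V_BSNH_BNSH_BNSH"):
        shapes = qkv_input_shapes(fmt, batch_size=2, sequence_length=8, num_heads=4, head_size=16)
        print(fmt, shapes, total_elements(shapes))
```

Because the element count is identical across layouts, latency differences between formats in a benchmark come from packing and transposition rather than from the amount of data.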
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498) ### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe: * Add an once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once; and change the cache to be read only after initialization. Previously, the content is not read-only so it might be changed by other thread and potentially cause buffer overrun. * The kernel initialization is not guarded (Although the factory of kernel loading has static mutex to guard multiple threading), so the mutable variable might be set by two different threads at the same time. Add an once_flag to avoid that. This requires need some workspace computation change as well. So I did not create a separated pull request. (2) **Bias for cross attention** That scenario has assumption that only query has bias, but not for key and value. However, such assumption is not verified in runtime and there was no comment of assumption, and there was no test case so the support of scenario was disabled by mistake. Actually, the scenario is used in whisper model (TODO: we shall add tests for whisper to CI pipeline, and also update fusion script to verify such assumptions if needed.) CUDA/CPU kernels supports bias for cross attention as long as bias is zero for key and value. I updated the check to support the scenario and added comments wherever there is such assumption. (3) **Fallback support** Previously, unfused kernel did not support packed qkv and packed kv formats. That means some case might fail since there is no fallback. I added new AddBiasTranpose cuda kernels for them to support fallback, so that all supported cases will not fail. #### Improvements (4) **QKV workspace size**. 
The logic for no_qkv_workspace could be easily out of sync since related code are scattered in different source files. I refactor the code to move all related code to one file (attention_prepare_qkv.cu) and add asserts, so that the logic can be in sync. (5) **Remove confusing concept of pass past in kv** parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is more clear. (6) **More coverage and less workspace and less transpose of flash and efficient attention** Previously, there is one condition does not run flash or efficient attention: ``` bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; ``` After this change, we can use flash and efficient attention for the case, and also less workspace. For example, cross attention with bias, the original code uses two additional workspaces: ``` transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) add bias: query => q, temp_k_workspace => k, temp_v_workspace => v ``` New logic is like ``` if (has bias) Add bias to query, key, value, and store in q, k, v workspace else Use query, key and value directly as q, k and v in kernel ``` We can see that, we do not need allocate temp_k_workspace and temp_v_workspace so use less memory. New code saved two transposes in this case. Flash and efficient attention supports BSNH or BNSH formats for k and v. In old code, k/v are also converted to BSNH format. Some is not necessary. I do some change to convert k/v to BSNH or BNSH case by case. So that there are more cases can be covered by flash or efficient attention to improve performance. (6) **Debugging support** Previously, there is less debug info. In this change, I add a flag for debug info in the AttentionData. 
So that we can output debug info during the processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disable dumping from cuda kernel. #### Summary of changes (1) Refactoring the CheckInputs, and pass in operator type. (2) Refactoring the PrepareQKV to support fallback for packed qkv or packed kv inputs. (3) Change a few case of PrepareQKV to allow more case covered by flash and efficient attention. (4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv (5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of assumption that key/value has no bias in this case. (6) Fix thread-safe issue in CumulatedSequenceLengthCache handling. (7) Add test cases to cover all supported scenarios. Current support scenarios for MultiHeadAttention for CUDA/CPU: | Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc | ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|--------- | BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed | BLN3H| - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU | BSNH | BLN2H| - | - | - | - | - | --- | kv packed <br> not support in CPU | BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention <br> bias for Q only | BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
        if self.has_past_input:
            shapes = {
                **shapes,
                "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size),
                "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size),
            }

        if self.has_present_output:
            shapes = {
                **shapes,
                "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size),
                "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size),
            }

        if self.has_bias:
            shapes["bias"] = (3 * self.num_heads * self.head_size,)

        if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
            shapes["mask"] = (self.batch_size,)
        elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
            shapes["mask"] = (self.batch_size, self.total_sequence_length)
        else:
            assert self.mask_format == AttentionMaskFormat.Mask_None

        if self.has_attn_bias:
            shapes["attn_bias"] = (
                1 if self.broadcast_attn_bias_dim_0 else self.batch_size,
                1 if self.broadcast_attn_bias_dim_1 else self.num_heads,
                self.sequence_length,
                self.total_sequence_length,
            )

        return shapes
def symbolic_shape_dict(self, input_format=None):
Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | 
ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. 
ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel -- | -- | -- | -- | -- | -- | -- | -- | -- Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default Q,K,V | 
FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
        shapes: Dict[str, Tuple] = {
            "output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
        }

        input_format = input_format or self.input_format
        if input_format == InputFormats.QKV_BSN3H:
            shapes = {
                **shapes,
                "query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size),
            }
        elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
            shapes = {
                **shapes,
                "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
                "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size),
            }
        elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
            shapes = {
                **shapes,
                "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
                "key": ("batch_size", "sequence_length", self.num_heads * self.head_size),
                "value": ("batch_size", "sequence_length", self.num_heads * self.head_size),
            }
        else:
            assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH
            shapes = {
                **shapes,
                "query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
                "key": ("batch_size", self.num_heads, "sequence_length", self.head_size),
                "value": ("batch_size", self.num_heads, "sequence_length", self.head_size),
            }
        if self.has_past_input:
            shapes = {
                **shapes,
                "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
                "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
            }

        if self.has_present_output:
            shapes = {
                **shapes,
                "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
                "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
            }

        if self.has_bias:
            shapes["bias"] = (3 * self.num_heads * self.head_size,)

        if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
            shapes["mask"] = ("batch_size",)
        elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
            shapes["mask"] = ("batch_size", "total_sequence_length")
        else:
            assert self.mask_format == AttentionMaskFormat.Mask_None

        if self.has_attn_bias:
            shapes["attn_bias"] = ("batch_size_or_1", "num_heads_or_1", "sequence_length", "total_sequence_length")

        return shapes

    def right_side_padding_masks(self):
        q_mask = torch.ones(self.batch_size, 1, self.sequence_length, 1, dtype=torch.bool, device=self.device)
        k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device)
        mask = torch.ones(
            self.batch_size,
            self.num_heads,
            self.sequence_length,
            self.total_sequence_length,
            dtype=torch.bool,
            device=self.device,
        )

        if self.mask_format != AttentionMaskFormat.Mask_None:
            for i, (m, n) in enumerate(zip(self.mask_index_q, self.mask_index_kv)):
                q_mask[i, :, m:, :] = False
                k_mask[i, :, n:, :] = False
                mask[i, :, m:, :] = False
                mask[i, :, :, n:] = False

        return q_mask, k_mask, mask

    def random_inputs(self, seed: int = 123, no_bias_k_v: bool = False):
        device = self.device
        dtype = self.dtype

        shape_dict = self.shape_dict()

        if seed > 0:
            torch.manual_seed(seed)

        shape = (self.batch_size, self.sequence_length, self.num_heads, self.head_size)
        q = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
        k = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)
        v = torch.empty(shape, device=device, dtype=dtype).normal_(mean=0, std=0.1)

        bias_q = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1)
        bias_k = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1)
        bias_v = torch.empty((self.num_heads * self.head_size,), device=device, dtype=dtype).normal_(mean=0, std=0.1)
        if no_bias_k_v:
            bias_k = torch.zeros_like(bias_k)
            bias_v = torch.zeros_like(bias_v)

        k_bnsh = k.transpose(1, 2)
        v_bnsh = v.transpose(1, 2)

        if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
            feeds = {
                "query": q.reshape(shape_dict["query"]),
FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
"key": k.reshape(shape_dict["key"]),
"value": v.reshape(shape_dict["value"]),
}
elif self.input_format == InputFormats.QKV_BSN3H:
query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
feeds = {
# Stack Q, K and V along the last axis, (B*S, N, 3*H), then reshape to packed QKV (B, S, N, 3, H).
"query": torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous(),
}
elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H:
key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size)
feeds = {
"query": q.reshape(shape_dict["query"]),
# Stack K and V along the last axis, (B*S, N, 2*H), then reshape to packed KV (B, S, N, 2, H).
"key": torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous(),
}
else:
assert self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH
feeds = {
"query": q.reshape(shape_dict["query"]),
"key": k_bnsh.contiguous(),
"value": v_bnsh.contiguous(),
}
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498) ### Description #### Issues Fixed (1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe: * Add an once_flag to CumulatedSequenceLengthCache to make sure it is only initialized once; and change the cache to be read only after initialization. Previously, the content is not read-only so it might be changed by other thread and potentially cause buffer overrun. * The kernel initialization is not guarded (Although the factory of kernel loading has static mutex to guard multiple threading), so the mutable variable might be set by two different threads at the same time. Add an once_flag to avoid that. This requires need some workspace computation change as well. So I did not create a separated pull request. (2) **Bias for cross attention** That scenario has assumption that only query has bias, but not for key and value. However, such assumption is not verified in runtime and there was no comment of assumption, and there was no test case so the support of scenario was disabled by mistake. Actually, the scenario is used in whisper model (TODO: we shall add tests for whisper to CI pipeline, and also update fusion script to verify such assumptions if needed.) CUDA/CPU kernels supports bias for cross attention as long as bias is zero for key and value. I updated the check to support the scenario and added comments wherever there is such assumption. (3) **Fallback support** Previously, unfused kernel did not support packed qkv and packed kv formats. That means some case might fail since there is no fallback. I added new AddBiasTranpose cuda kernels for them to support fallback, so that all supported cases will not fail. #### Improvements (4) **QKV workspace size**. 
The logic for no_qkv_workspace could be easily out of sync since related code are scattered in different source files. I refactor the code to move all related code to one file (attention_prepare_qkv.cu) and add asserts, so that the logic can be in sync. (5) **Remove confusing concept of pass past in kv** parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is more clear. (6) **More coverage and less workspace and less transpose of flash and efficient attention** Previously, there is one condition does not run flash or efficient attention: ``` bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; ``` After this change, we can use flash and efficient attention for the case, and also less workspace. For example, cross attention with bias, the original code uses two additional workspaces: ``` transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) add bias: query => q, temp_k_workspace => k, temp_v_workspace => v ``` New logic is like ``` if (has bias) Add bias to query, key, value, and store in q, k, v workspace else Use query, key and value directly as q, k and v in kernel ``` We can see that, we do not need allocate temp_k_workspace and temp_v_workspace so use less memory. New code saved two transposes in this case. Flash and efficient attention supports BSNH or BNSH formats for k and v. In old code, k/v are also converted to BSNH format. Some is not necessary. I do some change to convert k/v to BSNH or BNSH case by case. So that there are more cases can be covered by flash or efficient attention to improve performance. (6) **Debugging support** Previously, there is less debug info. In this change, I add a flag for debug info in the AttentionData. 
So that we can output debug info during the processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs; Add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disable dumping from cuda kernel. #### Summary of changes (1) Refactoring the CheckInputs, and pass in operator type. (2) Refactoring the PrepareQKV to support fallback for packed qkv or packed kv inputs. (3) Change a few case of PrepareQKV to allow more case covered by flash and efficient attention. (4) use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv (5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments of assumption that key/value has no bias in this case. (6) Fix thread-safe issue in CumulatedSequenceLengthCache handling. (7) Add test cases to cover all supported scenarios. Current support scenarios for MultiHeadAttention for CUDA/CPU: | Q | K | V | pastK| pastV | presentK| presentV | Bias | Op desc | ---- | ---- | ---- | ------ | ----- | --------- | -------- | -----|--------- | BSNH | BLNH| BLNH| - | - | - | - | QKV | not packed | BLN3H| - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU | BSNH | BLN2H| - | - | - | - | - | --- | kv packed <br> not support in CPU | BSNH | BNLH| BNLH| - | - | - | - | Q-- | cross attention <br> bias for Q only | BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
if self.has_past_input:
Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | 
ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient

#### Example CPU results

Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows.
ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1.

format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default

### Motivation and Context

To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
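The GPU rows above pair a latency with a throughput figure. As a rough cross-check, the sketch below recomputes throughput from latency, assuming the last two numeric columns are latency in seconds and TFLOPs/s, and assuming the usual estimate of two matrix multiplications per attention call (`Q @ K^T` and `P @ V`). The helper names `attention_flops` and `tflops_per_second` are illustrative and not taken from the benchmark script.

```python
def attention_flops(batch_size, seq_len_q, seq_len_kv, num_heads, head_size, causal=False):
    """Approximate FLOPs of one attention forward pass.

    Assumption: two matmuls, each costing 2 * B * N * S_q * S_kv * H FLOPs;
    causal masking is treated as halving the work.
    """
    total = 4 * batch_size * num_heads * seq_len_q * seq_len_kv * head_size
    return total // 2 if causal else total


def tflops_per_second(flops, latency_seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPs/s."""
    return flops / latency_seconds / 1e12


# Row "Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash" from the table.
flops = attention_flops(batch_size=8, seq_len_q=4096, seq_len_kv=4096, num_heads=32, head_size=128)
estimate = tflops_per_second(flops, latency_seconds=0.0115)
print(f"{estimate:.1f} TFLOPs/s")
```

The estimate lands within about one percent of the tabulated 190.8, which is consistent with the latency column being rounded to four decimals.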
            feeds = {
                **feeds,
                "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
                "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(
                    mean=0, std=0.1
                ),
            }
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)

### Description

#### Issues Fixed

(1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe:
* Add a once_flag to CumulatedSequenceLengthCache to make sure it is initialized only once, and make the cache read-only after initialization. Previously, the content was not read-only, so it might be changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading factory has a static mutex to guard against multiple threads), so the mutable variable might be set by two different threads at the same time. Add a once_flag to avoid that.

This requires some workspace computation changes as well, so I did not create a separate pull request.

(2) **Bias for cross attention**

This scenario assumes that only the query has bias, not the key and value. However, the assumption was not verified at runtime, there was no comment documenting it, and there was no test case, so support for the scenario was disabled by mistake. The scenario is actually used in the whisper model (TODO: we shall add tests for whisper to the CI pipeline, and also update the fusion script to verify such assumptions if needed).

CUDA/CPU kernels support bias for cross attention as long as the bias is zero for key and value. I updated the check to support the scenario and added comments wherever this assumption is made.

(3) **Fallback support**

Previously, the unfused kernel did not support packed qkv and packed kv formats. That means some cases might fail since there is no fallback. I added new AddBiasTranspose CUDA kernels for them to support fallback, so that all supported cases will not fail.

#### Improvements

(4) **QKV workspace size**.

The logic for no_qkv_workspace could easily get out of sync since the related code was scattered across different source files. I refactored the code to move all related code into one file (attention_prepare_qkv.cu) and added asserts, so that the logic stays in sync.

(5) **Remove confusing concept of pass past in kv**

parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. New code does not use past_key/past_value for cross attention, so the logic is clearer.

(6) **More coverage and less workspace and less transpose of flash and efficient attention**

Previously, one condition prevented flash or efficient attention from running:
```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```
After this change, we can use flash and efficient attention for this case, with less workspace.

For example, for cross attention with bias, the original code uses two additional workspaces:
```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```

The new logic is:
```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```

We do not need to allocate temp_k_workspace and temp_v_workspace, so less memory is used. The new code also saves two transposes in this case.

Flash and efficient attention support BSNH or BNSH formats for k and v. In the old code, k/v were also converted to BSNH format, which is not always necessary. I changed the code to convert k/v to BSNH or BNSH case by case, so that more cases can be covered by flash or efficient attention to improve performance.

(7) **Debugging support**

Previously, there was little debug info. In this change, I add a flag for debug info in AttentionData so that we can output debug info during processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disabling dumping from the CUDA kernel.

#### Summary of changes

(1) Refactor CheckInputs, and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the assumption that key/value has no bias in this case.
(6) Fix thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.

Currently supported scenarios for MultiHeadAttention on CUDA/CPU:

| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | ---------
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not support in CPU
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not support in CPU
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer)

### Motivation and Context

https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
        if self.has_bias:
            feeds["bias"] = torch.concat([bias_q, bias_k, bias_v], dim=0).reshape(shape_dict["bias"]).contiguous()

        # Generate padding mask
        if self.mask_format != AttentionMaskFormat.Mask_None:
            self.mask_index_kv = torch.randint(
                1, self.total_sequence_length + 1, (self.batch_size,), dtype=torch.int32, device=self.device
            )
            if self.past_sequence_length > 0:
                self.mask_index_q = (
                    torch.ones(self.batch_size, dtype=torch.int32, device=self.device) * self.sequence_length
                )
            else:  # prompt case
                self.mask_index_q = self.mask_index_kv.clone()

        mask = None
        if self.mask_format == AttentionMaskFormat.Mask_1D_Key_SeqLen:
            mask = self.mask_index_kv.clone()
        elif self.mask_format == AttentionMaskFormat.Mask_2D_Key_PaddingMask:
            k_mask = torch.ones(self.batch_size, 1, self.total_sequence_length, 1, dtype=torch.bool, device=self.device)
            for i, n in enumerate(self.mask_index_kv):
                k_mask[i, :, n:, :] = False
            mask = k_mask.reshape(self.batch_size, self.total_sequence_length)
        else:
            assert self.mask_format == AttentionMaskFormat.Mask_None

        if mask is not None:
            feeds = {**feeds, "mask": mask.to(dtype=torch.int32)}  # mask is int32 (not bool) for MultiHeadAttention op.

        if self.has_attn_bias:
            attn_bias = torch.empty(
                (
                    1 if self.broadcast_attn_bias_dim_0 else self.batch_size,
                    1 if self.broadcast_attn_bias_dim_1 else self.num_heads,
                    self.sequence_length,
                    self.total_sequence_length,
                ),
                device=self.device,
                dtype=dtype,
            ).normal_(mean=0, std=0.1)
            feeds["attn_bias"] = attn_bias

        return feeds
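To make the `Mask_2D_Key_PaddingMask` branch concrete, here is a small torch-free sketch of the values such a mask holds, assuming each batch entry keeps its first `n` key positions and pads the rest. The function `key_padding_mask` is a hypothetical helper for illustration only; the benchmark itself builds the mask with torch tensors and then casts it to int32.

```python
def key_padding_mask(key_lengths, total_sequence_length):
    """Return one row per batch entry: 1 for valid key positions, 0 for padding."""
    return [
        [1 if position < length else 0 for position in range(total_sequence_length)]
        for length in key_lengths
    ]


# Three batch entries with 2, 4 and 1 valid key positions out of 4.
mask = key_padding_mask([2, 4, 1], total_sequence_length=4)
for row in mask:
    print(row)
```

Each row has length `total_sequence_length`, so the result has the same (batch_size, total_sequence_length) layout as the reshaped `k_mask` in the method.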

    def get_input_output_names(self):
        if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
            inputs, outputs = ["query", "key", "value"], ["output"]
        elif self.input_format == InputFormats.QKV_BSN3H:
            inputs, outputs = ["query"], ["output"]
        elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H:
            inputs, outputs = ["query", "key"], ["output"]
        else:
            inputs, outputs = ["query", "key", "value"], ["output"]
        if self.has_bias:
            assert self.input_format != InputFormats.Q_KV_BSNH_BSN2H
            inputs = [*inputs, "bias"]

        if self.mask_format != AttentionMaskFormat.Mask_None:
            inputs = [*inputs, "mask"]

        if self.has_attn_bias:
            inputs = [*inputs, "attn_bias"]
        if self.has_past_input:
            inputs = [*inputs, "past_key", "past_value"]

        if self.has_present_output:
            outputs = [*outputs, "present_key", "present_value"]

        return inputs, outputs
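def fill_optional_mha_inputs(input_names):
    """Return the MultiHeadAttention input list, using "" for absent optional inputs.

    For example, ["query", "key", "value", "past_key", "past_value"] becomes
    ["query", "key", "value", "", "", "", "past_key", "past_value"].
    """
    inputs = ["query", "key", "value", "bias", "mask", "attn_bias", "past_key", "past_value"]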
    # Replace optional inputs that are not in input_names with an empty string,
    # so that the remaining inputs keep their positions.
    inputs_with_optional = [name if name in input_names else "" for name in inputs]

    # Remove empty strings at the end of the list (guarding against an empty list).
    while inputs_with_optional and inputs_with_optional[-1] == "":
        inputs_with_optional.pop(-1)

    return inputs_with_optional
def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False):
    input_names, output_names = config.get_input_output_names()

    float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT
    nodes = [
        helper.make_node(
            "MultiHeadAttention",
            fill_optional_mha_inputs(input_names),
            output_names,
| 16384 | 128 | 16 | 64.6 | 54.8 | 25.7 QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1 QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7 QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4 QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5 QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8 QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9 QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1 QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6 QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7 QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6 QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5 QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1 QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5 QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2 QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6 QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15 QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84 QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75 QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95 ### Known Issues NVCC uses huge memory while compiling flash attention CUDA kernel. Linux build with CUDA might fail when machine has limited memory while number of CPUs is large. Walkaround is to use a build machine with larger memory, or use argument like `--nvcc_threads 1` to limit nvcc threads in build. ### Motivation and Context Increases speed and efficiency of MHA or Packed MHA. --------- Co-authored-by: Tianlei Wu <tlwu@microsoft.com> Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
            "MultiHeadAttention_0",
            num_heads=config.num_heads,
            unidirectional=int(config.causal),
            scale=config.scale,
            mask_filter_value=float("-inf"),
            domain="com.microsoft",
        ),
    ]
    shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict()
    inputs = [
        helper.make_tensor_value_info(
            input_name, TensorProto.INT32 if input_name == "mask" else float_type, list(shape_dict[input_name])
        )
        for input_name in input_names
        if input_name
    ]
    outputs = [
        helper.make_tensor_value_info(output_name, float_type, list(shape_dict[output_name]))
        for output_name in output_names
        if output_name
| 16384 | 128 | 16 | 64.6 | 54.8 | 25.7 QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1 QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7 QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4 QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5 QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8 QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9 QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1 QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6 QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7 QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6 QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5 QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1 QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5 QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2 QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6 QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15 QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84 QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75 QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95 ### Known Issues NVCC uses huge memory while compiling flash attention CUDA kernel. Linux build with CUDA might fail when machine has limited memory while number of CPUs is large. Walkaround is to use a build machine with larger memory, or use argument like `--nvcc_threads 1` to limit nvcc threads in build. ### Motivation and Context Increases speed and efficiency of MHA or Packed MHA. --------- Co-authored-by: Tianlei Wu <tlwu@microsoft.com> Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
    ]
    graph = helper.make_graph(
        nodes,
        "MultiHeadAttention_Graph",
        inputs,
        outputs,
    )
    model = helper.make_model(graph)
    return model.SerializeToString()
def create_ort_session(
    config: MultiHeadAttentionConfig,
    session_options=None,
    attention_kernel=SdpaKernel.DEFAULT,
    use_symbolic_shape: bool = True,
    use_tf32: bool = True,
) -> CudaSession:
    if config.verbose:
        print(f"create session for {vars(config)}")
    onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=use_symbolic_shape)
    if config.provider == "CUDAExecutionProvider":
        device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
        provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
        provider_options["sdpa_kernel"] = int(attention_kernel)
        provider_options["use_tf32"] = int(use_tf32)
        providers = [(config.provider, provider_options), "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
2024-07-27 01:45:14 +00:00
    ort_session = InferenceSession(onnx_model_str, session_options, providers=providers)
    return ort_session
def create_session(
[CUDA] Fix MultiHeadAttention thread safe and bias support (#21498)

### Description

#### Issues Fixed

(1) **TRT cross attention not thread safe**. [Core changes like this](https://github.com/microsoft/onnxruntime/commit/6fd7aba3d4f27089de1a4ead86a2ae7e667c18b6) are used to make it thread-safe:

* Add a once_flag to CumulatedSequenceLengthCache to make sure it is initialized only once, and make the cache read-only after initialization. Previously, the content was not read-only, so it could be changed by another thread and potentially cause a buffer overrun.
* The kernel initialization was not guarded (although the kernel-loading factory has a static mutex to guard against multiple threads), so the mutable variable could be set by two different threads at the same time. Add a once_flag to avoid that.

This requires some workspace computation changes as well, so I did not create a separate pull request.

(2) **Bias for cross attention**

That scenario assumes that only the query has bias, not the key and value. However, the assumption was not verified at runtime, there was no comment documenting it, and there was no test case, so support for the scenario was disabled by mistake. The scenario is actually used in the whisper model (TODO: we shall add tests for whisper to the CI pipeline, and also update the fusion script to verify such assumptions if needed).

CUDA/CPU kernels support bias for cross attention as long as the bias is zero for key and value. I updated the check to support the scenario and added comments wherever the assumption is made.

(3) **Fallback support**

Previously, the unfused kernel did not support the packed qkv and packed kv formats, so some cases might fail since there was no fallback. I added new AddBiasTranspose CUDA kernels for them to support fallback, so that all supported cases will not fail.

#### Improvements

(4) **QKV workspace size**.

The logic for no_qkv_workspace could easily get out of sync since the related code was scattered across different source files. I refactored the code to move it all into one file (attention_prepare_qkv.cu) and added asserts, so that the logic stays in sync.

(5) **Remove confusing concept of pass past in kv**

parameters.pass_past_in_kv is confusing since the k/v in cross attention is not past state. Remove it and use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH instead. The new code does not use past_key/past_value for cross attention, so the logic is clearer.

(6) **More coverage and less workspace and less transpose of flash and efficient attention**

Previously, one condition prevented flash or efficient attention from running:

```
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
```

After this change, we can use flash and efficient attention for this case, with less workspace.

For example, for cross attention with bias, the original code uses two additional workspaces:

```
transpose: past_key (BxNxSxH) => temp_k_workspace (BxSxNxH), past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
add bias: query => q, temp_k_workspace => k, temp_v_workspace => v
```

The new logic is:

```
if (has bias)
    Add bias to query, key, value, and store in q, k, v workspace
else
    Use query, key and value directly as q, k and v in kernel
```

We do not need to allocate temp_k_workspace and temp_v_workspace, so less memory is used. The new code also saves two transposes in this case.

Flash and efficient attention support BSNH or BNSH formats for k and v. In the old code, k/v were also converted to BSNH format, which is not necessary in some cases. I changed the code to convert k/v to BSNH or BNSH case by case, so that more cases can be covered by flash or efficient attention to improve performance.

(7) **Debugging support**

Previously, there was little debug info. In this change, I add a flag for debug info in AttentionData so that debug info can be output during processing. Also add functions to consolidate the dumping of inputs, QKV processing and outputs, and add an environment variable `ORT_ENABLE_GPU_DUMP` to allow disabling dumping from the CUDA kernel.

#### Summary of changes

(1) Refactor CheckInputs and pass in the operator type.
(2) Refactor PrepareQKV to support fallback for packed qkv or packed kv inputs.
(3) Change a few cases of PrepareQKV to allow more cases to be covered by flash and efficient attention.
(4) Use parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH to replace parameters.pass_past_in_kv.
(5) Allow bias input for Q_K_V_BSNH_BNSH_BNSH, and add comments on the assumption that key/value has no bias in this case.
(6) Fix the thread-safety issue in CumulatedSequenceLengthCache handling.
(7) Add test cases to cover all supported scenarios.

Currently supported scenarios for MultiHeadAttention for CUDA/CPU:

| Q | K | V | pastK | pastV | presentK | presentV | Bias | Op desc |
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ----- | --------- |
| BSNH | BLNH | BLNH | - | - | - | - | QKV | not packed |
| BLN3H | - | - | - | - | - | - | QKV | qkv packed <br> not supported on CPU |
| BSNH | BLN2H | - | - | - | - | - | --- | kv packed <br> not supported on CPU |
| BSNH | BNLH | BNLH | - | - | - | - | Q-- | cross attention <br> bias for Q only |
| BSNH | BLNH | BLNH | - | - | BNTH | BNTH | QKV | no past <br> only present |
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH | BNTH | QKV | past and present <br> (not share buffer) |

### Motivation and Context

https://github.com/microsoft/onnxruntime/issues/18854
2024-07-31 16:01:05 +00:00
    config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT, use_tf32: bool = True
) -> CudaSession:
    ort_session = create_ort_session(
        config, session_options, attention_kernel, use_symbolic_shape=False, use_tf32=use_tf32
    )
    cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
    shape_dict = config.shape_dict()
    cuda_session.allocate_buffers(shape_dict)
    return cuda_session
class OrtMultiHeadAttention:
    """A wrapper of ORT MultiHeadAttention to test relevance and performance."""

    def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_tf32: bool = True):
        self.ort_session = create_session(config, session_options, use_tf32=use_tf32)
        self.feed_dict = config.random_inputs()

    def infer(self, run_options=None, synchronize=True):
        return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize)

def measure_latency(cuda_session: CudaSession, input_dict):
    start = time.time()
    _ = cuda_session.infer(input_dict)
    end = time.time()
    return end - start
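
# A single timed run, as above, is sensitive to warm-up effects and timer noise.
# The sketch below shows one common alternative: discard a few warm-up runs, then
# report the median of several samples. `measure_median_latency`, `warmup` and
# `repeats` are hypothetical names and are not part of this benchmark script.
# For GPU work the callable is assumed to block until the device has finished.

```python
import statistics
import time


def measure_median_latency(run, warmup=3, repeats=10):
    """Time a zero-argument callable and return the median wall-clock seconds."""
    for _ in range(warmup):
        run()  # warm-up runs are executed but not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Stand-in workload for illustration; a real benchmark would call the session here.
latency = measure_median_latency(lambda: sum(range(10_000)))
```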
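

def flops(batch, sequence_length, head_size, num_heads, causal):
    # Two matrix multiplications (Q*K^T and P*V), each counted as 2 FLOPs per multiply-add.
    # Causal masking skips about half of the work.
    return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1)

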
def tflops_per_second(flop, time):
    try:
        return (flop / time / 10**12) if not math.isnan(time) else 0.0
    except ZeroDivisionError:
        return None
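
# Worked example of the FLOP count and the TFLOPS conversion. This is a standalone
# sketch: `attention_flops` and `to_tflops` are hypothetical names that restate the
# two helpers above so the block runs on its own, and the 1.5 ms latency is an
# illustrative value, not a measurement.

```python
import math


def attention_flops(batch, sequence_length, head_size, num_heads, causal):
    # Same formula as flops(): 4 * B * S^2 * N * H, halved for causal attention.
    return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1)


def to_tflops(flop, seconds):
    # Same behavior as tflops_per_second(): NaN latency -> 0.0, zero latency -> None.
    try:
        return (flop / seconds / 10**12) if not math.isnan(seconds) else 0.0
    except ZeroDivisionError:
        return None


flop = attention_flops(batch=4, sequence_length=2048, head_size=128, num_heads=32, causal=False)
print(flop)  # 274877906944, i.e. 2**38
print(to_tflops(flop, 0.0015))  # roughly 183 TFLOPS at an assumed 1.5 ms latency
```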


def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str:
    kernel_names = {
        SdpaKernel.DEFAULT: "ort:default",
        SdpaKernel.FLASH_ATTENTION: "ort:flash",
        SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient",
        SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn",
        SdpaKernel.MATH: "ort:math",
    }
    assert attention_kernel in kernel_names
    return kernel_names[attention_kernel]
def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str:
    # CPU Flash Attention does not support causal and kv cache etc.
    if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0):
        if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1":
            return "ort:flash"
    return "ort:math"
# ------------------------------------------------------------------
# Functions for benchmarking PyTorch SDPA
# ------------------------------------------------------------------
def benchmark_torch_function(repeats: int, func: Callable, *args, **kwargs) -> float:
    warmup = 5
    for _ in range(warmup):
        func(*args, **kwargs)
    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return timer.timeit(number=repeats).median
def run_torch_sdpa(
    batch_size: int,
    q_seq_len: int,
    kv_seq_len: int,
    num_heads: int,
    head_size: int,
    causal: bool,
    device,
    dtype,
    has_mask: bool = False,
    mask_dim: int = 2,
    mask_dtype=torch.bool,
    backend: Optional[int] = None,
    repeats: int = 100,
):
    q_shape = (batch_size, num_heads, q_seq_len, head_size)
    kv_shape = (batch_size, num_heads, kv_seq_len, head_size)
    q = torch.randn(q_shape, device=device, dtype=dtype)
    k = torch.randn(kv_shape, device=device, dtype=dtype)
    v = torch.randn(kv_shape, device=device, dtype=dtype)
    attn_mask = None
    if has_mask:
        mask_shape = (batch_size, num_heads, q_seq_len, kv_seq_len) if mask_dim == 4 else (q_seq_len, kv_seq_len)
        attn_mask = torch.ones(mask_shape, dtype=mask_dtype, device=device)
    context = sdpa_kernel(backend) if backend is not None else nullcontext()
    with context:
        average_latency = benchmark_torch_function(
            repeats,
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=causal,
            attn_mask=attn_mask,
        )
    return average_latency


def get_test_configs(args: argparse.Namespace):
    use_gpu: bool = args.use_gpu
    if args.batch_size > 0:
        run_unfused = args.sequence_length + args.past_sequence_length <= (2048 if use_gpu else 1024)
        return [
            (
                args.batch_size,
                args.sequence_length,
                args.past_sequence_length,
                args.num_heads,
                args.head_size,
                run_unfused,
            ),
        ]

    if use_gpu:
        # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused)
        configs = [
            (32, 512, 0, 64, 32, True),
            (32, 512, 0, 128, 16, True),
            (16, 1024, 0, 64, 32, True),
            (16, 1024, 0, 128, 16, True),
            (8, 2048, 0, 64, 32, True),
            (8, 2048, 0, 128, 16, False),
            (4, 4096, 0, 64, 32, False),
            (4, 4096, 0, 128, 16, False),
            (2, 8192, 0, 64, 32, False),
            (2, 8192, 0, 128, 16, False),
            (1, 16384, 0, 64, 32, False),
            (1, 16384, 0, 128, 16, False),
            # stable diffusion
            (1, 4096, 0, 8, 40, False),
            (1, 4096, 0, 8, 80, False),
            (1, 4096, 0, 8, 160, False),
            (4, 4096, 0, 8, 40, False),
            (4, 4096, 0, 8, 80, False),
            (4, 4096, 0, 8, 160, False),
            (1, 16384, 0, 8, 40, False),
            (1, 16384, 0, 8, 80, False),
            (1, 16384, 0, 8, 160, False),
            # bert-base
            (128, 128, 0, 12, 64, True),
            (64, 128, 0, 12, 64, True),
            (128, 384, 0, 12, 64, True),
            (64, 384, 0, 12, 64, True),
            (128, 512, 0, 12, 64, True),
            (64, 512, 0, 12, 64, True),
            # TNLGv4
            (4, 2048, 0, 32, 128, True),
            (4, 4096, 0, 32, 128, False),
            (8, 2048, 0, 32, 128, False),
            (8, 4096, 0, 32, 128, False),
        ]
    else:
        configs = [
            # TNLGv4
            (1, 128, 0, 32, 128, True),
            (1, 256, 0, 32, 128, True),
            (1, 512, 0, 32, 128, True),
            (1, 1024, 0, 32, 128, True),
            # (1, 2048, 0, 32, 128, True),
            # bert-base
            (1, 128, 0, 12, 64, True),
            (1, 384, 0, 12, 64, True),
            (1, 512, 0, 12, 64, True),
            (4, 128, 0, 12, 64, True),
            (4, 384, 0, 12, 64, True),
            (4, 512, 0, 12, 64, True),
            # bert-large
            (1, 128, 0, 16, 64, True),
            (1, 384, 0, 16, 64, True),
            (1, 512, 0, 16, 64, True),
            (4, 128, 0, 16, 64, True),
            (4, 384, 0, 16, 64, True),
            (4, 512, 0, 16, 64, True),
        ]
    return configs


def get_compute_capability():
    assert torch.cuda.is_available()
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor
    return sm
class CaptureStdout:
    """Capture everything written to the stdout file descriptor, including native (C++) output."""

    def __init__(self):
        self.fd = sys.stdout.fileno()
        self.chunk_size = 1024
        self.output = b""

    def _capture(self):
        # Read until EOF, which arrives once every write end of the pipe is closed.
        chunks = []
        while chunk := os.read(self._pipe_reader, self.chunk_size):
            chunks.append(chunk)
        self.output = b"".join(chunks)

    def __enter__(self):
        # Keep a duplicate of the real stdout so it can be restored in __exit__.
        self._duped_fd = os.dup(self.fd)
        self._pipe_reader, pipe_writer = os.pipe()
        os.dup2(pipe_writer, self.fd)
        os.close(pipe_writer)
        self._capture_thread = threading.Thread(target=self._capture)
        self._capture_thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Closing the redirected fd closes the last write end, so the reader thread sees EOF.
        os.close(self.fd)
        self._capture_thread.join()
        os.close(self._pipe_reader)
        os.dup2(self._duped_fd, self.fd)
        os.close(self._duped_fd)
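The redirection used by `CaptureStdout` can be tried in isolation. The block below is a minimal sketch of the same technique, not the benchmark's own code: `FdCapture` and `capture_demo` are hypothetical names, the file descriptor is passed explicitly (defaulting to 1) instead of being taken from `sys.stdout.fileno()`, and the sample text is made up for illustration.

```python
import os
import threading


class FdCapture:
    """Hypothetical stand-in for CaptureStdout: captures bytes written to one fd."""

    def __init__(self, fd: int = 1, chunk_size: int = 1024):
        self.fd = fd
        self.chunk_size = chunk_size
        self.output = b""

    def _drain(self):
        # os.read returns b"" at EOF, which ends the loop.
        chunks = []
        while chunk := os.read(self._reader, self.chunk_size):
            chunks.append(chunk)
        self.output = b"".join(chunks)

    def __enter__(self):
        self._saved = os.dup(self.fd)  # remember the real target
        self._reader, writer = os.pipe()
        os.dup2(writer, self.fd)  # fd now points at the pipe
        os.close(writer)
        self._thread = threading.Thread(target=self._drain)
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.close(self.fd)  # last write end closed -> reader gets EOF
        self._thread.join()
        os.close(self._reader)
        os.dup2(self._saved, self.fd)  # restore the real target
        os.close(self._saved)


def capture_demo() -> str:
    # os.write bypasses Python's buffered sys.stdout, much like native code does.
    with FdCapture() as cap:
        os.write(1, b"SdpaKernel=FLASH_ATTENTION\n")
    return cap.output.decode()
```

Note that `output` is only complete after the `with` block exits, because the reader thread is joined in `__exit__`. Text written through Python's `print` may sit in the `sys.stdout` buffer, so it should be flushed before leaving the block if it needs to be captured as well.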
def sdpa_kernel_from_debug_info(
    config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel, sess_options: SessionOptions
):
    os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1"
    captured_text = None

    try:
        with CaptureStdout() as captured:
            session = create_session(config, sess_options, attention_kernel=attention_kernel)
            input_dict = config.random_inputs()
            session.infer(input_dict)
        captured_text = captured.output.decode()
    except Exception as e:
        print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}")
    finally:
        os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0"

    if captured_text is not None:
        m = re.search("SdpaKernel=(?P<kernel>[A-Z_]+)", captured_text)
        if m is not None:
            name = m.group("kernel")
            kernel_names = {
                "FLASH_ATTENTION": "ort:flash",
                "EFFICIENT_ATTENTION": "ort:efficient",
                "CUDNN_FLASH_ATTENTION": "ort:cudnn",
                "MATH": "ort:math",
                "TRT_FUSED_ATTENTION": "ort:trt_fmha",
                "TRT_FLASH_ATTENTION": "ort:trt_flash",
                "TRT_CROSS_ATTENTION": "ort:trt_cross",
                "TRT_CAUSAL_ATTENTION": "ort:trt_causal",
            }
            return kernel_names[name]
        else:
            print("Failed to get sdpa kernel from debug info:", captured_text)

    return None
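The parsing step of `sdpa_kernel_from_debug_info` can be exercised without a GPU or an ONNX Runtime session. The following is a sketch under stated assumptions: `parse_sdpa_kernel` is a hypothetical helper, the sample debug lines are invented (only the `SdpaKernel=NAME` token is taken from the code above), and unknown names yield `None` via `dict.get` where the original lookup would raise `KeyError`.

```python
import re
from typing import Optional

# Same mapping as in sdpa_kernel_from_debug_info.
KERNEL_NAMES = {
    "FLASH_ATTENTION": "ort:flash",
    "EFFICIENT_ATTENTION": "ort:efficient",
    "CUDNN_FLASH_ATTENTION": "ort:cudnn",
    "MATH": "ort:math",
    "TRT_FUSED_ATTENTION": "ort:trt_fmha",
    "TRT_FLASH_ATTENTION": "ort:trt_flash",
    "TRT_CROSS_ATTENTION": "ort:trt_cross",
    "TRT_CAUSAL_ATTENTION": "ort:trt_causal",
}


def parse_sdpa_kernel(text: str) -> Optional[str]:
    """Return the benchmark label for the first SdpaKernel=... token in text."""
    m = re.search(r"SdpaKernel=(?P<kernel>[A-Z_]+)", text)
    if m is None:
        return None
    # Assumption: unknown kernel names map to None rather than raising.
    return KERNEL_NAMES.get(m.group("kernel"))


# Hypothetical debug output; the real line may carry other fields.
sample = "Operator=MultiHeadAttention SdpaKernel=CUDNN_FLASH_ATTENTION\n"
label = parse_sdpa_kernel(sample)
```

Because `[A-Z_]+` is greedy, `CUDNN_FLASH_ATTENTION` is matched in full and is not mistaken for the shorter `FLASH_ATTENTION` key.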
def run_tflops_test(
    csv_writer: csv.DictWriter,
    args: argparse.Namespace,
):
    use_gpu: bool = args.use_gpu
    enable_cuda_graph: bool = args.use_cuda_graph
    causal: bool = args.causal
    intra_op_num_threads: int = args.intra_op_num_threads
    repeats: int = args.repeats
    print(f"run_tflops_test: causal={causal}")
    if use_gpu:
        device_id = torch.cuda.current_device()
        device = torch.device("cuda", device_id)
        formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]
        provider = "CUDAExecutionProvider"
        # flash attention is available for sm >= 80
        sm = get_compute_capability()
        if sm >= 80:
[CUDA] cuDNN Flash Attention (#21629)

### Description

- [x] Add cuDNN flash attention using cudnn frontend, and enable it in MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.

The cuDNN SDPA is disabled by default. To enable it, all of the following are required:
(1) cuDNN 9.3 or newer is installed.
(2) The environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1` is set, or the CUDA provider option `sdpa_kernel=8` is set.
(3) The device has compute capability >= 8.0.

Note that some combinations of parameters might be rejected due to limited support of head dimension or sequence lengths.

Future Works:
(1) FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add API to support ragged batching (padding removed in inputs).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, and k/v are converted to either BSNH or BNSH format. Experiments may show whether converting q to BNSH is better in some cases.

### Example Benchmark Results on H100

The following tests are on the FP16 MultiHeadAttention operator without attention mask and attention bias.

#### Test Setting 1

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient

#### Test Setting 2

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH) | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH) | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH) | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient
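The `tflops` column in the tables above can be related to `average_latency` with a short calculation. The sketch below assumes that latency is reported in seconds and that the forward FLOP count of non-causal attention is 4 × batch × S_q × S_kv × heads × head_size (two batched matrix multiplications at 2 FLOPs per multiply-add). The helper names `attention_flops` and `tflops_per_second` are illustrative and are not claimed to match the benchmark script.

```python
def attention_flops(batch_size, seq_len_q, seq_len_kv, num_heads, head_size):
    """Approximate forward FLOPs of non-causal attention (assumed formula).

    QK^T and PV are each a batched matmul costing
    2 * batch * heads * S_q * S_kv * head_size FLOPs.
    """
    return 4 * batch_size * seq_len_q * seq_len_kv * num_heads * head_size


def tflops_per_second(flops, latency_in_seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPs per second."""
    return flops / latency_in_seconds / 1e12


# Test Setting 1 from the table above: batch 16, sequence length 256,
# 32 heads, head size 128, latency 0.000075 (torch:flash row).
flops = attention_flops(16, 256, 256, 32, 128)
print(f"{tflops_per_second(flops, 0.000075):.1f} TFLOPs/s")
```

Under these assumptions the result is roughly 229 TFLOPs/s, close to the 229.5 reported in the table. The small gap is expected because the tabulated latency is rounded to six decimal places.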
2024-08-20 15:50:22 +00:00
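The commit message above names two ways to enable the cuDNN SDPA kernel: the environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1` or the CUDA provider option `sdpa_kernel=8`. A minimal sketch of both follows. The helper `cudnn_sdpa_providers` and the `model_path` placeholder are hypothetical, and the sketch only builds the providers list rather than creating a session, so it runs without a GPU.

```python
import os

# Value taken from the commit message above: sdpa_kernel=8 requests cuDNN flash attention.
CUDNN_FLASH_ATTENTION = 8


def cudnn_sdpa_providers(device_id=0):
    """Build a providers list that requests the cuDNN SDPA kernel (hypothetical helper)."""
    cuda_options = {"device_id": device_id, "sdpa_kernel": CUDNN_FLASH_ATTENTION}
    return [("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider"]


# Option 1: environment variable, set before the inference session is created.
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option 2: CUDA provider option, passed when the session is created, for example
#   onnxruntime.InferenceSession(model_path, providers=cudnn_sdpa_providers())
# where model_path is a placeholder for an ONNX model file.
providers = cudnn_sdpa_providers()
print(providers)
```

Either option alone should be sufficient according to the commit message. The provider option is scoped to one session, whereas the environment variable affects the whole process.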
            backends = [
                SdpaKernel.DEFAULT,
                SdpaKernel.FLASH_ATTENTION,
                SdpaKernel.EFFICIENT_ATTENTION,
                SdpaKernel.CUDNN_FLASH_ATTENTION,
                SdpaKernel.MATH,
            ]
        else:
            backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH]
    else:
        device_id = 0
        device = torch.device("cpu")
        formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH]
        enable_cuda_graph = False
        provider = "CPUExecutionProvider"
        backends = [SdpaKernel.DEFAULT]
    configs = get_test_configs(args)
    print(
        "\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tsdpa_kernel\trequest_kernel"
    )
    for input_format in formats:
        for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
            config = MultiHeadAttentionConfig(
                batch_size=batch_size,
                sequence_length=sequence_length,
                num_heads=num_heads,
                head_size=head_size,
                causal=causal,
                use_kv_cache=past_sequence_length > 0,
                past_sequence_length=past_sequence_length,
                max_cache_sequence_length=None,
                kv_sequence_length=None,
                provider=provider,
                enable_cuda_graph=enable_cuda_graph,
                device=device,
                dtype=torch.float16 if use_gpu else torch.float,
                share_past_present_buffer=False,
                input_format=input_format,
                has_attn_bias=args.has_attn_bias,
                broadcast_attn_bias_dim_0=args.broadcast_attn_bias_dim_0,
                broadcast_attn_bias_dim_1=args.broadcast_attn_bias_dim_1,
            )
            for attention_kernel in backends:
                sess_options = SessionOptions()
                sess_options.intra_op_num_threads = intra_op_num_threads
if use_gpu:
    request_kernel = get_gpu_kernel_name(attention_kernel)
else:
    request_kernel = get_cpu_kernel_name(config)
if "math" in request_kernel:
    # Skip large sequence length for Unfused kernel to avoid OOM.
    if not enable_unfused:
        if config.verbose:
            print(f"skip unfused kernel for {vars(config)}")
        continue
    # Unfused kernel does not support packed QKV or packed KV formats.
    if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]:
        if config.verbose:
            print(f"skip input_format for {vars(config)}")
        continue
if use_gpu:
    actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options)
    if actual_kernel is None:
        print(f"Warning: skip {config} since kernel from debug info is None")
        continue
else:
    # CPU has no debug info for now.
    actual_kernel = request_kernel

session = create_session(config, sess_options, attention_kernel=attention_kernel)
input_dict = config.random_inputs()
# warm up session
try:
    _ = measure_latency(session, input_dict)
except Exception as e:
    print(f"Failed to run {request_kernel=} for {config=}. Exception: {e}")
    continue

latency_list = []
for _ in range(repeats):
    latency = measure_latency(session, input_dict)
    latency_list.append(latency)
average_latency = statistics.mean(latency_list)

del session
format_str = InputFormats.input_format_str(input_format)
# compute TFLOPS per second
speed = None
if past_sequence_length == 0:
    speed = tflops_per_second(
        flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency
    )

row = {
    "use_gpu": use_gpu,
    "enable_cuda_graph": enable_cuda_graph,
    "format": format_str,
    "causal": causal,
    "batch_size": batch_size,
    "sequence_length": sequence_length,
    "past_sequence_length": past_sequence_length,
    "num_heads": num_heads,
    "head_size": head_size,
    "has_attn_bias": args.has_attn_bias,
    "broadcast_attn_bias_dim_0": args.broadcast_attn_bias_dim_0,
    "broadcast_attn_bias_dim_1": args.broadcast_attn_bias_dim_1,
2024-07-27 01:45:14 +00:00
"intra_op_num_threads": intra_op_num_threads,
"average_latency": average_latency,
"tflops": speed,
"request_kernel": request_kernel,
"kernel": actual_kernel,
}
csv_writer.writerow(row)
speed = f"{speed:.2f}" if speed is not None else "NA"
print(
    f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t"
    f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
    f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}"
)
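The reporting step above writes one CSV row per benchmark run and then prints a tab-separated summary in which a missing TFLOPs value is shown as "NA". The block below is a minimal standalone sketch of that pattern, not the script's actual code: `format_tflops` and `write_result_row` are hypothetical helper names, only the five row keys visible above are used (the real row carries more columns), and the `request_kernel` values are illustrative.

```python
import csv
import io

# Only the keys visible in the snippet above; the real row has more columns.
FIELDNAMES = ["intra_op_num_threads", "average_latency", "tflops", "request_kernel", "kernel"]


def format_tflops(speed):
    """Display rule used for the console line: two decimals, or "NA" when absent."""
    return f"{speed:.2f}" if speed is not None else "NA"


def write_result_row(csv_writer, intra_op_num_threads, average_latency, speed, request_kernel, actual_kernel):
    """Hypothetical helper: build the row dict and hand it to csv.DictWriter."""
    row = {
        "intra_op_num_threads": intra_op_num_threads,
        "average_latency": average_latency,
        "tflops": speed,  # None is written by the csv module as an empty field
        "request_kernel": request_kernel,
        "kernel": actual_kernel,
    }
    csv_writer.writerow(row)
    return row


# Write to an in-memory buffer so the sketch has no file-system side effects.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=FIELDNAMES)
writer.writeheader()
write_result_row(writer, 0, 0.0015, 179.5, "flash", "ort:flash")
write_result_row(writer, 8, 0.0005, None, "", "ort:flash")
csv_text = buffer.getvalue()
```

Keeping the raw value (or `None`) in the CSV and formatting only for the console line means the file stays machine-readable while the printed table stays aligned.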
def run_torch_test(
    csv_writer: csv.DictWriter,
    args: argparse.Namespace,
):
    use_gpu: bool = args.use_gpu
    causal: bool = args.causal
    configs = get_test_configs(args)
    if use_gpu:
        if not torch.cuda.is_available():
            return
        device_id = torch.cuda.current_device()
        device = torch.device("cuda", device_id)
        dtype = torch.float16
        backends = [
            None,
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.MATH,
        ]
    else:
        device = torch.device("cpu")
        dtype = torch.float32
        backends = [None]
    backend_names = {
        SDPBackend.FLASH_ATTENTION: "torch:flash",
        SDPBackend.EFFICIENT_ATTENTION: "torch:efficient",
        SDPBackend.CUDNN_ATTENTION: "torch:cudnn",
        SDPBackend.MATH: "torch:math",
        None: "torch:default",
    }
    # Test PyTorch latency
    for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
        for backend in backends:
            if backend == SDPBackend.MATH and not enable_unfused:
                continue
            if backend == SDPBackend.FLASH_ATTENTION and platform.system() != "Linux":
                continue
            backend_name = backend_names[backend]
            try:
                with torch.no_grad():
                    torch_latency = run_torch_sdpa(
                        batch_size,
                        sequence_length,
                        sequence_length,
                        num_heads,
                        head_size,
                        causal,
                        has_mask=False,
                        mask_dim=2,
                        mask_dtype=torch.bool,
                        device=device,
                        dtype=dtype,
                        backend=backend,
                        repeats=args.repeats,
                    )
            except RuntimeError:
                # Skip this backend when PyTorch raises RuntimeError
                # (for example, when the backend is unavailable for this configuration).
                continue
            speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency)
            input_format = "Q,K,V"
            print(
                f"{input_format}\t{causal}\t{False}\t{batch_size}\t"
                f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
                f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}\t{backend_name}"
            )
            row = {
                "use_gpu": use_gpu,
                "enable_cuda_graph": False,
                "format": input_format,
                "causal": causal,
                "batch_size": batch_size,
                "sequence_length": sequence_length,
                "past_sequence_length": past_sequence_length,
                "num_heads": num_heads,
                "head_size": head_size,
                "has_attn_bias": False,
                "broadcast_attn_bias_dim_0": False,
                "broadcast_attn_bias_dim_1": False,
                "intra_op_num_threads": torch.get_num_threads(),
                "average_latency": torch_latency,
                "tflops": speed,
                "request_kernel": backend_name,
"kernel": backend_name,
}
csv_writer.writerow(row)
def run_tflops_tests(args):
    features = "gpu" if args.use_gpu else "cpu"
    if args.causal:
        features += "_causal"
    if args.past_sequence_length > 0:
        features += "_past"
    csv_filename = "benchmark_mha_{}_{}_{}.csv".format(
        features,
        "torch" if args.torch else "ort",
        datetime.now().strftime("%Y%m%d-%H%M%S"),
    )
    with open(csv_filename, mode="a", newline="") as csv_file:
        column_names = [
            "use_gpu",
            "enable_cuda_graph",
            "format",
            "causal",
            "batch_size",
            "sequence_length",
            "past_sequence_length",
            "num_heads",
            "head_size",
            "has_attn_bias",
            "broadcast_attn_bias_dim_0",
            "broadcast_attn_bias_dim_1",
            "intra_op_num_threads",
            "average_latency",
            "tflops",
            "request_kernel",
Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | 
ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. 
            "kernel",
        ]
        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()

        if args.torch:
            run_torch_test(csv_writer, args)
        else:
            run_tflops_test(csv_writer, args)
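
The header-then-rows pattern used with `csv.DictWriter` above can be sketched in isolation. This is a minimal illustration only: the helper `write_rows`, the column names, and the sample row are hypothetical and are not taken from the benchmark script, and an in-memory buffer stands in for the real CSV file.

```python
import csv
import io


def write_rows(stream, rows, column_names):
    """Write a header line, then one CSV line per result dictionary."""
    writer = csv.DictWriter(stream, fieldnames=column_names)
    writer.writeheader()
    for row in rows:
        writer.writerow(row)


# Hypothetical columns and values, chosen only for illustration.
columns = ["input_format", "sequence_length", "latency_seconds", "kernel"]
rows = [
    {
        "input_format": "Q,K,V",
        "sequence_length": 128,
        "latency_seconds": 0.0005,
        "kernel": "ort:flash",
    },
]

buffer = io.StringIO()
write_rows(buffer, rows, columns)
csv_text = buffer.getvalue()
```

Because the format name `Q,K,V` contains the delimiter, the default minimal quoting wraps that field in double quotes, so the file still parses back into four columns.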
def plot_prompt_performance(
model_name: str,
batch_size: int,
num_heads: int,
head_size: int,
max_seq_len: int,
):
import triton
formats = InputFormats.get_name_list()
# Exclude cross attention since kernel crashes for some configuration.
formats = formats[:-1]
settings = {
"line_vals": formats,
"line_names": ["ORT-MHA:" + name for name in formats],
"styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)],
}
Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | 
ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. 
ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel -- | -- | -- | -- | -- | -- | -- | -- | -- Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default Q,K,V | 
FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
sm = get_compute_capability()
configs = [
triton.testing.Benchmark(
x_names=["sequence_length"],
x_vals=[2**i for i in range(6, 17) if 2**i <= max_seq_len],
line_arg="input_format",
ylabel="ms",
**settings,
plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{head_size}-fp16",
args={
"batch_size": batch_size,
"num_heads": num_heads,
"head_size": head_size,
},
)
]
@triton.testing.perf_report(configs)
def benchmark(
input_format: str,
sequence_length: int,
batch_size: int,
num_heads: int,
head_size: int,
device="cuda",
):
warmup = 15
repeat = 100
config: MultiHeadAttentionConfig = MultiHeadAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
num_heads=num_heads,
head_size=head_size,
Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | 
ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. 
            causal=False,
            past_sequence_length=0,
            kv_sequence_length=sequence_length if input_format == "Q,K',V'" else None,
            max_cache_sequence_length=max_seq_len,
            provider="CUDAExecutionProvider",
            enable_cuda_graph=False,
            device=device,
            dtype=torch.float16,
            use_kv_cache=False,
            input_format=InputFormats.convert(input_format),
        )

        obj = OrtMultiHeadAttention(config)
        ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
        return ms

    benchmark.run(save_path=".", print_data=True)
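The `benchmark` function above returns the latency measured by `triton.testing.do_bench`, which is in milliseconds. A minimal sketch of turning that latency into TFLOPS follows. The helper names `attention_flops` and `tflops_from_ms` are hypothetical and not part of this script, and the FLOP count is an assumption that covers only the two attention matmuls (QK^T and PV), ignoring softmax, scaling and any transposes.

```python
def attention_flops(batch_size, sequence_length, kv_sequence_length, num_heads, head_size, causal=False):
    """Approximate FLOPs of one attention forward pass (hypothetical helper).

    Assumption: only the QK^T and PV matmuls are counted. Each one performs
    B * N * S * S_kv * H multiply-adds, i.e. 2 * B * N * S * S_kv * H FLOPs.
    """
    flops = 4 * batch_size * num_heads * sequence_length * kv_sequence_length * head_size
    # With a causal mask roughly half of the score matrix is skipped.
    return flops // 2 if causal else flops


def tflops_from_ms(flops, latency_ms):
    """Convert a FLOP count and a latency in milliseconds to TFLOPS (hypothetical helper)."""
    if latency_ms <= 0:
        return 0.0
    seconds = latency_ms * 1e-3
    return flops / seconds / 1e12


if __name__ == "__main__":
    # Example shape: batch 4, sequence length 2048, 32 heads, head size 128.
    f = attention_flops(4, 2048, 2048, 32, 128)
    print(f"{f} FLOPs, {tflops_from_ms(f, 1.5):.1f} TFLOPS at 1.5 ms")
```

Counting only the matmuls keeps the figure comparable across kernels, but it is an estimate of useful work rather than of the operations a particular kernel executes.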
def run_bert_performance_test():
    """
    Run prompt performance tests for BERT-base and BERT-large configurations.
    """
    configures = [
        # Each tuple is (batch_size, num_heads, head_size, max_seq_len, model_name).
        # (1, 32, 128, 8192, "TNLGv4"),
        # (4, 32, 128, 8192, "TNLGv4"),
        (1, 12, 64, 1024, "BertBase"),
        (16, 12, 64, 1024, "BertBase"),
        (1, 16, 64, 1024, "BertLarge"),
        (8, 16, 64, 1024, "BertLarge"),
    ]
    for batch_size, num_heads, head_size, max_seq_len, model_name in configures:
        plot_prompt_performance(
            batch_size=batch_size,
            num_heads=num_heads,
            head_size=head_size,
            max_seq_len=max_seq_len,
            model_name=model_name,
        )
Flash Attention v2 MHA (#17227) ### Description Integrate Flash Attention V2 to PackedMultiHeadAttention, MultiHeadAttention and Attention operators. Flash Attention v2 source code is from https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src. We did some change to remove dependency on Torch, then removed backward and bfloat16 related code. Add benchmark script (see benchmark_mha.sh) to compare different attention kernels for MultiHeadAttention operator. Current limitations for Flash Attention in PackedMultiHeadAttention, MultiHeadAttention and Attention operators: * Relative Position Bias is not supported * Different hidden size for Q and V is not supported * Only float16 is supported * Padding/attention mask is not supported * For MultiHeadAttention, when there is past or present input, bias shall be provided to activate flash attention * For Attention, past or present inputs will deactivate flash attention * Causal is not supported Some limitations (like attention mask and causal) might be removed later. Currently, Flash Attention v2 only works in Linux. For Windows, we will enable later with Cutlass 3.2. Two environment variables can be used for testing purpose: (1) `ORT_DISABLE_FLASH_ATTENTION` to disable flash attention. Default value is 0 (enable). Set it to "1" to disable it. (2) `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV`. Default value is "513", which means that we only enable flash attention when sequence length is larger than 512 for packed QKV format. Set it to "0" if you want to use flash attention v2 whenever possible. ### Speedup The following result is from Standard_ND96amsr_A100_v4 VM (A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per second for MultiHeadAttention operator. 
There are 3 input formats: * `Q,K,V` means separated inputs query, key and value of BxSxNH * `Q,KV` means packed KV, where key is 5D: BxSxNx2xH * `QKV` means packed QKV, where query is 5D: BxSxNx3xH Note that flash attention cannot use packed QKV format, so extra Transpose is needed. We found that TensorRT kernel is faster for sequence length <= 512 for packed QKV. The reason might be no transpose is needed for TensorRT kernel in this format. We also notice that, TensorRT kernel is faster for stable diffusion 512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while flash attention v2 is faster for 1024x1024 image (see seq_len=16384, heads=8, head_dim=40 below). input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s) -- | -- | -- | -- | -- | -- | -- | -- Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3 Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7 Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3 Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4 Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8 Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7 Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7 Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3 Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7 Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6 Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2 Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8 Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8 Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5 Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8 Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2 Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2 Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8 Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1 Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6 Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7 Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7 Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3 Q,K,V | 128 
| 384 | 12 | 64 | 83.5 | 60.8 | 55.7 Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8 Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1 Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4 Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1 Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6 Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8 Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6 Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5 Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7 Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1 Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3 Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9 Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6 Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2 Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8 Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5 Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6 Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6 Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8 Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8 Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5 Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3 Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8 Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8 Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9 Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0 Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0 Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9 Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9 Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8 QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3 QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9 QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6 QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2 QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9 QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5 QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7 QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2 QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7 QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5 QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2 QKV | 1 
| 16384 | 128 | 16 | 64.6 | 54.8 | 25.7 QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1 QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7 QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4 QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5 QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8 QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9 QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1 QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6 QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7 QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6 QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5 QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1 QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5 QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2 QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6 QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15 QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84 QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75 QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95 ### Known Issues NVCC uses huge memory while compiling flash attention CUDA kernel. Linux build with CUDA might fail when machine has limited memory while number of CPUs is large. Walkaround is to use a build machine with larger memory, or use argument like `--nvcc_threads 1` to limit nvcc threads in build. ### Motivation and Context Increases speed and efficiency of MHA or Packed MHA. --------- Co-authored-by: Tianlei Wu <tlwu@microsoft.com> Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
2023-08-31 20:52:21 +00:00
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)

### Description

* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.

For the Q,K,V format, torch uses BNSH layout while ort uses BSNH layout, so the comparison is not apples-to-apples. However, a large latency difference could still be a warning sign.

#### Example GPU results

Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.

format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient

#### Example CPU results

Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.

format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default

### Motivation and Context

To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
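The GPU results above report both latency and TFLOPs/s. As a rough illustration of how the two columns relate, the sketch below assumes the common accounting of 4·B·N·S²·H floating-point operations for non-causal attention (two matrix multiplications, two FLOPs per multiply-add), halved for causal attention. The function name `attention_tflops` and this formula are assumptions made for illustration; the script's own calculation is not shown in this section and may differ.

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s, causal=False):
    """Estimate attention throughput in TFLOPs/s from a measured latency.

    Hypothetical helper: the FLOP count below is an assumed convention,
    not necessarily the one used by benchmark_mha.py.
    """
    if latency_s <= 0:
        raise ValueError("latency_s must be positive")
    # Q @ K^T and P @ V each take B*N*S*S*H multiply-adds (2 FLOPs apiece).
    flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size
    if causal:
        # Assumption: a causal mask skips roughly half of the score matrix.
        flops /= 2
    return flops / latency_s / 1e12


if __name__ == "__main__":
    # First GPU row above: batch 4, sequence 2048, 32 heads, head size 128.
    print(f"{attention_tflops(4, 2048, 32, 128, 0.0015):.1f} TFLOPs/s")
```

With the latency rounded to 0.0015 s as printed in the table, this estimate comes out at about 183 TFLOPs/s, in the same range as the reported 179.5; the gap is what one would expect from rounding the latency to four decimal places.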
def _parse_arguments():
    parser = argparse.ArgumentParser(description="Benchmark MultiHeadAttention for ONNX Runtime and PyTorch.")

    parser.add_argument(
        "--use_gpu",
        required=False,
        action="store_true",
        help="Use GPU for inference.",
    )
    parser.set_defaults(use_gpu=False)

    parser.add_argument(
        "--use_cuda_graph",
        required=False,
        action="store_true",
        help="Use cuda graph in onnxruntime.",
    )
    parser.set_defaults(use_cuda_graph=False)

    parser.add_argument(
        "--intra_op_num_threads",
        required=False,
        type=int,
        choices=[0, 1, 2, 4, 8, 16],
        default=0,
        help="intra_op_num_threads for onnxruntime. ",
    )

    parser.add_argument(
        "--causal",
        required=False,
        action="store_true",
        help="test unidirectional",
    )
    parser.set_defaults(causal=False)
    parser.add_argument(
        "-b",
        "--batch_size",
2024-07-27 01:45:14 +00:00
        required=False,
        type=int,
        default=0,
        help="batch size",
    )
    parser.add_argument(
        "-s",
        "--sequence_length",
        required=False,
        type=int,
        default=512,
        help="sequence length",
    )
    parser.add_argument(
        "-p",
        "--past_sequence_length",
        required=False,
        type=int,
        default=0,
        help="past sequence length",
    )
    parser.add_argument(
        "-n",
        "--num_heads",
        required=False,
        type=int,
        default=16,
        help="number of attention heads",
    )
    parser.add_argument(
        "-d",
        "--head_size",
        required=False,
        type=int,
        default=64,
        help="hidden dimension per head",
    )
    parser.add_argument(
        "-r",
        "--repeats",
        required=False,
        type=int,
        default=0,
        help="number of repeats for performance test",
    )
    parser.add_argument(
        "--torch",
        required=False,
        action="store_true",
        help="test pytorch instead of onnxruntime",
    )
    parser.set_defaults(torch=False)
    parser.add_argument(
        "--has_attn_bias",
        required=False,
        action="store_true",
        help="has attention bias",
    )
    parser.set_defaults(has_attn_bias=False)
    parser.add_argument(
        "--broadcast_attn_bias_dim_0",
        required=False,
        action="store_true",
        help="broadcast attention bias dimension 0",
    )
    parser.set_defaults(broadcast_attn_bias_dim_0=False)
    parser.add_argument(
        "--broadcast_attn_bias_dim_1",
        required=False,
        action="store_true",
        help="broadcast attention bias dimension 1",
    )
    parser.set_defaults(broadcast_attn_bias_dim_1=False)
    args = parser.parse_args()
    return args
if __name__ == "__main__":
Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | 
ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. 
ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel -- | -- | -- | -- | -- | -- | -- | -- | -- Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default Q,K,V | 
FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency.
2024-07-27 01:45:14 +00:00
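The GPU result rows above report both a latency in seconds and a TFLOPs figure. The relationship between the two can be sketched as follows, assuming the conventional FLOP count for non-causal attention of 4 * batch * seq_len^2 * num_heads * head_size (two matrix multiplications at 2 FLOPs per multiply-add). The helper name `attention_tflops` is hypothetical and is not taken from the benchmark script; the formula is an assumption that happens to agree with the reported rows to within latency rounding.

```python
def attention_tflops(batch_size: int, seq_len: int, num_heads: int,
                     head_size: int, latency_s: float) -> float:
    """Estimate TFLOPs/s for one non-causal attention call.

    Assumption: FLOPs = 4 * B * S * S * N * H, covering Q @ K^T and
    softmax(QK^T) @ V at 2 FLOPs per multiply-add. Softmax cost is ignored.
    """
    flops = 4 * batch_size * seq_len * seq_len * num_heads * head_size
    return flops / latency_s / 1e12


# Row "Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5": the estimate lands near
# the reported value; the gap comes from the latency being rounded to 4 digits.
estimate = attention_tflops(4, 2048, 32, 128, 0.0015)
print(f"{estimate:.1f} TFLOPs/s")
```

Because the printed latency is rounded, the estimate will not reproduce the table's TFLOPs column exactly; it is only useful as a sanity check on the order of magnitude.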
args = _parse_arguments()
print(f"arguments:{args}")
if args.repeats == 0:
    args.repeats = 10000 if args.use_gpu else 100
if args.use_gpu:
    assert torch.cuda.is_available()
    if not args.torch:
        assert "CUDAExecutionProvider" in get_available_providers()
if args.torch:
    assert Version(torch.__version__) >= Version("2.3.0")
    assert args.past_sequence_length == 0
if args.use_gpu and args.batch_size == 0 and not args.torch:
    if platform.system() == "Linux":
        s = torch.cuda.Stream()
        with torch.cuda.stream(s), torch.no_grad():
            run_bert_performance_test()
Update benchmark_mha.py to compare with PyTorch SDPA (#21449)

### Description

* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a csv file.
* Use the sdpa_kernel cuda provider option instead of environment variables for better control.
* Add arguments (`--use_gpu`, `--causal` etc.) to allow testing different scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.

For the Q,K,V format, torch uses BNSH layout while ort uses BSNH layout, so the comparison is not apples-to-apples. However, a large latency difference could still be a warning sign.

#### Example GPU results

Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1.

format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient

#### Example CPU results

Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1.

format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default

### Motivation and Context

To compare with PyTorch SDPA on CPU and CUDA latency.
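
The `tflops` column in the GPU results can be approximated from the latency column. The sketch below assumes the usual attention FLOP count of two matrix multiplications (Q·Kᵀ and P·V), each costing 2·B·N·S_q·S_kv·H operations, halved for causal attention. The function names `attention_flops` and `tflops_per_second` are hypothetical helpers for illustration and are not confirmed to match the script's own code.

```python
def attention_flops(batch_size, seq_len_q, seq_len_kv, num_heads, head_size, causal=False):
    """Approximate FLOPs of one attention forward pass (assumed formula).

    Two matmuls (Q @ K^T and P @ V), each 2 * B * N * S_q * S_kv * H operations.
    Causal masking skips roughly half of the work.
    """
    total = 4 * batch_size * num_heads * seq_len_q * seq_len_kv * head_size
    return total // 2 if causal else total


def tflops_per_second(flop_count, latency_seconds):
    """Convert a FLOP count and a latency in seconds to TFLOPs per second."""
    return flop_count / latency_seconds / 1e12


# First GPU row: Q,KV format, batch 4, sequence length 2048, 32 heads, head size 128.
flop_count = attention_flops(4, 2048, 2048, 32, 128)
estimate = tflops_per_second(flop_count, 0.0015)
print(f"{estimate:.1f} TFLOPs/s")
```

With the rounded latency of 0.0015 s this gives roughly 183 TFLOPs/s, close to the reported 179.5. The gap is plausibly because the table rounds latency to four decimal places, though that is an inference rather than something the commit message states.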
2024-07-27 01:45:14 +00:00
run_tflops_tests(args)