mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)
### Description
- [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe.
- [x] Add multi-threading tests
Previously, the kernel parameters params is stored as a member of mha
runner, which means that different threads might change the params at
the same time and impacts the other threads.
For example, if batch_size and seq_len was changed by another thread to
larger values in setup(...), buffer overrun might happen in run(...)
because a kernel could read/write memory out of range of allocated
buffers.
In new implementation, I change the api and remove mutable member
variables to make it thread safe. Below is summary of change:
Before:
```
class FusedMHARunnerFP16v2::mhaImpl {
void setup(int seq_len, int batch_size) {
// change scalar params
}
void run(input, output) {
// change params for input and output pointers
// launch kernel using params
}
Fused_multihead_attention_params_v2 params; // mutable, not thread-safe
}
```
After:
```
class FusedMHARunnerFP16v2::FmhaImpl {
void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
// change params
}
void run(params, input, output) {
// change params with input and output pointers
// launch kernel using params
}
}
```
### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
https://github.com/microsoft/onnxruntime/issues/21413
This commit is contained in:
parent
11bf309736
commit
a6c5e2cd20
10 changed files with 537 additions and 237 deletions
|
|
@ -149,8 +149,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
nullptr == relative_position_bias &&
|
||||
parameters.past_sequence_length == 0 &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, true);
|
||||
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, true);
|
||||
if (use_causal_fused_runner) {
|
||||
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
|
|
@ -171,8 +171,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
nullptr == present &&
|
||||
nullptr == relative_position_bias &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, false);
|
||||
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, false);
|
||||
|
||||
if (use_fused_runner) {
|
||||
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
|
||||
|
|
@ -184,8 +184,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
// In case some kernel not loaded due to shared memory limit, we need to double check here.
|
||||
const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
|
||||
if (fused_fp16_runner_->isValid(S)) {
|
||||
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
|
||||
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
|
||||
fused_runner = fused_fp16_runner_.get();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -245,12 +245,10 @@ Status FusedTrtSelfAttention(
|
|||
|
||||
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(data.fused_runner);
|
||||
|
||||
const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
|
||||
const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length);
|
||||
|
||||
// B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed.
|
||||
const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
|
||||
|
||||
fused_fp16_runner->setup(S, B);
|
||||
const int b = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
|
||||
|
||||
if (!causal) {
|
||||
assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H);
|
||||
|
|
@ -261,12 +259,12 @@ Status FusedTrtSelfAttention(
|
|||
packed_qkv = data.query;
|
||||
}
|
||||
|
||||
fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
|
||||
fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream);
|
||||
DUMP_TENSOR("fused output", data.output,
|
||||
batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
|
||||
} else {
|
||||
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
|
||||
fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
|
||||
fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream);
|
||||
DUMP_TENSOR("fused causal output", data.output,
|
||||
batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -193,8 +193,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
parameters.sequence_length == parameters.kv_sequence_length &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, false);
|
||||
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, false);
|
||||
if (use_fused_runner) {
|
||||
// Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
|
|
@ -206,8 +206,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
// In case some kernel not loaded due to shared memory limit, we need to double check here.
|
||||
const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
|
||||
if (fused_fp16_runner_->isValid(S)) {
|
||||
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
|
||||
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
|
||||
fused_runner = fused_fp16_runner_.get();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,11 +55,11 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro
|
|||
|
||||
// Check whether we can use fused kernel
|
||||
int sm = device_prop.major * 10 + device_prop.minor;
|
||||
bool is_fMHA_supported = FusedMHARunnerFP16v2::is_supported(sm,
|
||||
parameters.head_size,
|
||||
parameters.sequence_length,
|
||||
enable_trt_flash_attention_,
|
||||
false /*causal*/);
|
||||
bool is_fMHA_supported = FusedMHARunnerFP16v2::IsSupported(sm,
|
||||
parameters.head_size,
|
||||
parameters.sequence_length,
|
||||
enable_trt_flash_attention_,
|
||||
false /*causal*/);
|
||||
|
||||
if (!is_fMHA_supported) {
|
||||
return fused_runner;
|
||||
|
|
@ -72,8 +72,8 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro
|
|||
}
|
||||
|
||||
// In case some kernel not loaded due to shared memory limit, we need to double check here.
|
||||
const int S = fused_fp16_runner_->getSFromMaxSeqLen(parameters.sequence_length);
|
||||
if (fused_fp16_runner_->isValid(S)) {
|
||||
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(parameters.sequence_length);
|
||||
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
|
||||
fused_runner = fused_fp16_runner_.get();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -459,10 +459,9 @@ Status FusedScaledDotProductAttention(
|
|||
parameters.token_count, stream);
|
||||
|
||||
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
|
||||
const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
|
||||
fused_fp16_runner->setup(S, batch_size);
|
||||
|
||||
fused_fp16_runner->run(data.workspace, data.cumulative_sequence_length, data.output, stream);
|
||||
const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
|
||||
fused_fp16_runner->Run(batch_size, normalized_seq_len,
|
||||
data.workspace, data.cumulative_sequence_length, data.output, stream);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -575,10 +575,8 @@ Status FusedAttentionTrt(
|
|||
}
|
||||
|
||||
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
|
||||
const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
|
||||
fused_fp16_runner->setup(S, batch_size);
|
||||
|
||||
fused_fp16_runner->run(qkv, data.cumulative_sequence_length, data.output, stream);
|
||||
const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
|
||||
fused_fp16_runner->Run(batch_size, normalized_seq_len, qkv, data.cumulative_sequence_length, data.output, stream);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Modifications: Update interface and implmentation to be thread-safe
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/flash_attention/fmha_flash_attention.h"
|
||||
|
|
@ -34,28 +38,28 @@ void set_alpha_fp16(uint32_t& alpha, float norm) {
|
|||
alpha = temp.u32;
|
||||
}
|
||||
|
||||
class FusedMHARunnerFP16v2::mhaImpl {
|
||||
class FusedMHARunnerFP16v2::FmhaImpl {
|
||||
public:
|
||||
mhaImpl(FusedMHARunnerFP16v2* interface)
|
||||
: interface(interface),
|
||||
sm(interface->mSm),
|
||||
xmmaKernel(getXMMAKernelsV2(DATA_TYPE_FP16, sm)) {
|
||||
FmhaImpl(FusedMHARunnerFP16v2* interface, int sm)
|
||||
: interface_(interface),
|
||||
sm_(sm),
|
||||
xmma_kernel_(getXMMAKernelsV2(DATA_TYPE_FP16, sm)) {
|
||||
ORT_ENFORCE((sm == kSM_70 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89),
|
||||
"Unsupported architecture");
|
||||
|
||||
flash_attention_kernel = nullptr;
|
||||
if (interface->mEnableFlashAttention) {
|
||||
flash_attention_kernel = get_flash_attention_kernels(DATA_TYPE_FP16, sm);
|
||||
flash_kernel_ = nullptr;
|
||||
if (interface_->enable_flash_attention_) {
|
||||
flash_kernel_ = get_flash_attention_kernels(DATA_TYPE_FP16, sm);
|
||||
}
|
||||
|
||||
params.clear();
|
||||
}
|
||||
|
||||
~mhaImpl() {}
|
||||
~FmhaImpl() {}
|
||||
|
||||
void setup(const int seq_len, const int B) {
|
||||
// For bert and vit, use flash attention when sequence length is larger than the threshold.
|
||||
use_flash_attention = is_flash_attention(seq_len);
|
||||
void Setup(Fused_multihead_attention_params_v2& params,
|
||||
int sequence_length, // normalized sequence length
|
||||
int batch_size,
|
||||
bool& use_flash_attention) const {
|
||||
use_flash_attention = UseFlashAttention(sequence_length);
|
||||
|
||||
params.force_unroll = use_flash_attention;
|
||||
|
||||
|
|
@ -67,27 +71,27 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
warps_m = 4;
|
||||
warps_n = 1;
|
||||
} else {
|
||||
if (sm == 70) {
|
||||
if (seq_len == 64 || seq_len == 96) {
|
||||
if (sm_ == 70) {
|
||||
if (sequence_length == 64 || sequence_length == 96) {
|
||||
warps_m = 2;
|
||||
warps_n = 2;
|
||||
} else if (seq_len == 128) {
|
||||
} else if (sequence_length == 128) {
|
||||
warps_m = 1;
|
||||
warps_n = 4;
|
||||
} else if (seq_len == 256 || seq_len == 384) {
|
||||
} else if (sequence_length == 256 || sequence_length == 384) {
|
||||
warps_m = 1;
|
||||
warps_n = 8;
|
||||
} else {
|
||||
ORT_ENFORCE(false, "Unsupported sequence length");
|
||||
}
|
||||
} else {
|
||||
if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) {
|
||||
if (sequence_length == 32 || sequence_length == 64 || sequence_length == 96 || sequence_length == 128) {
|
||||
warps_m = 2;
|
||||
warps_n = 2;
|
||||
} else if (seq_len == 192 || seq_len == 256) {
|
||||
} else if (sequence_length == 192 || sequence_length == 256) {
|
||||
warps_m = 1;
|
||||
warps_n = 4;
|
||||
} else if (seq_len == 384) {
|
||||
} else if (sequence_length == 384) {
|
||||
warps_m = 1;
|
||||
warps_n = 8;
|
||||
} else {
|
||||
|
|
@ -97,11 +101,11 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
}
|
||||
|
||||
// The number of threads per CTA.
|
||||
threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
|
||||
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
|
||||
xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m);
|
||||
size_t xmmas_m = (sequence_length + 16 * warps_m - 1) / (16 * warps_m);
|
||||
|
||||
const float scale_bmm1 = interface->mScale;
|
||||
const float scale_bmm1 = interface_->scale_;
|
||||
const float scale_softmax = 1.f; // Seems to be only required for int8
|
||||
const float scale_bmm2 = 1.f;
|
||||
|
||||
|
|
@ -109,20 +113,21 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
set_alpha_fp16(params.scale_softmax, scale_softmax);
|
||||
set_alpha_fp16(params.scale_bmm2, scale_bmm2);
|
||||
|
||||
params.b = B;
|
||||
params.h = interface->mNumHeads;
|
||||
params.s = seq_len;
|
||||
params.d = interface->mHeadSize;
|
||||
params.b = batch_size;
|
||||
params.h = interface_->num_heads_;
|
||||
params.s = sequence_length;
|
||||
params.d = interface_->head_size_;
|
||||
|
||||
params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
|
||||
params.qkv_stride_in_bytes = 3 * interface_->num_heads_ * interface_->head_size_ * sizeof(half);
|
||||
params.packed_mask_stride_in_bytes = xmmas_m * threads_per_cta * sizeof(uint32_t);
|
||||
params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half);
|
||||
|
||||
has_causal_mask = false;
|
||||
params.o_stride_in_bytes = interface_->num_heads_ * interface_->head_size_ * sizeof(half);
|
||||
}
|
||||
|
||||
void setup_causal_masked_fmha(const int seq_len, const int B) {
|
||||
const float scale_bmm1 = interface->mScale;
|
||||
void SetupCausal(Fused_multihead_attention_params_v2& params,
|
||||
int sequence_length, // normalized sequence length
|
||||
int batch_size,
|
||||
bool& use_flash_attention) const {
|
||||
const float scale_bmm1 = interface_->scale_;
|
||||
const float scale_softmax = 1.f; // Seems to be only required for int8
|
||||
const float scale_bmm2 = 1.f;
|
||||
|
||||
|
|
@ -130,16 +135,17 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
set_alpha_fp16(params.scale_softmax, scale_softmax);
|
||||
set_alpha_fp16(params.scale_bmm2, scale_bmm2);
|
||||
|
||||
params.b = B;
|
||||
params.h = interface->mNumHeads;
|
||||
params.s = seq_len;
|
||||
params.d = interface->mHeadSize;
|
||||
params.b = batch_size;
|
||||
params.h = interface_->num_heads_;
|
||||
params.s = sequence_length;
|
||||
params.d = interface_->head_size_;
|
||||
|
||||
params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
|
||||
params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half);
|
||||
params.qkv_stride_in_bytes = 3 * interface_->num_heads_ * interface_->head_size_ * sizeof(half);
|
||||
params.o_stride_in_bytes = interface_->num_heads_ * interface_->head_size_ * sizeof(half);
|
||||
|
||||
// fallback to original fmha_v2 when head_size <= 64 and seq_len <- 128
|
||||
use_flash_attention = interface->mEnableFlashAttention;
|
||||
use_flash_attention = interface_->enable_flash_attention_;
|
||||
|
||||
// fallback to original fmha_v2 when head_size <= 64 and sequence_length <= 128
|
||||
if (params.d <= 64 && params.s <= 128) {
|
||||
use_flash_attention = false;
|
||||
// get max sequence length
|
||||
|
|
@ -152,97 +158,87 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
|
||||
// set flags
|
||||
params.force_unroll = use_flash_attention;
|
||||
has_causal_mask = true;
|
||||
}
|
||||
|
||||
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) {
|
||||
void Run(Fused_multihead_attention_params_v2& params,
|
||||
const void* input,
|
||||
const void* cu_seqlens,
|
||||
void* output,
|
||||
cudaStream_t stream,
|
||||
bool use_flash_attention,
|
||||
bool has_causal_mask) const {
|
||||
params.qkv_ptr = const_cast<void*>(input);
|
||||
params.o_ptr = output;
|
||||
params.cu_seqlens = static_cast<int*>(const_cast<void*>(cu_seqlens));
|
||||
|
||||
if (use_flash_attention && flash_attention_kernel != nullptr && !has_causal_mask) {
|
||||
flash_attention_kernel->run(params, stream);
|
||||
if (use_flash_attention && flash_kernel_ != nullptr && !has_causal_mask) {
|
||||
flash_kernel_->run(params, stream);
|
||||
} else {
|
||||
xmmaKernel->run(params, stream, use_flash_attention, has_causal_mask);
|
||||
xmma_kernel_->run(params, stream, use_flash_attention, has_causal_mask);
|
||||
}
|
||||
|
||||
CUDA_CALL_THROW(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
bool isValid(int s) const {
|
||||
if (is_flash_attention(s)) {
|
||||
return (flash_attention_kernel != nullptr) && flash_attention_kernel->isValid(s);
|
||||
bool IsValid(int sequence_length) const {
|
||||
if (UseFlashAttention(sequence_length)) {
|
||||
return (flash_kernel_ != nullptr) && flash_kernel_->isValid(sequence_length);
|
||||
}
|
||||
|
||||
return xmmaKernel->isValid(s);
|
||||
return xmma_kernel_->isValid(sequence_length);
|
||||
}
|
||||
|
||||
int getSFromMaxSeqLen(const int max_seq_len) const {
|
||||
if (is_flash_attention(max_seq_len)) {
|
||||
int NormalizeSequenceLength(int max_seq_len) const {
|
||||
if (UseFlashAttention(max_seq_len)) {
|
||||
return max_seq_len;
|
||||
}
|
||||
|
||||
int seq_len = max_seq_len;
|
||||
int sequence_length = max_seq_len;
|
||||
if (max_seq_len <= 32) {
|
||||
seq_len = (sm == 70) ? 64 : 32;
|
||||
sequence_length = (sm_ == 70) ? 64 : 32;
|
||||
} else if (max_seq_len <= 64) {
|
||||
seq_len = 64;
|
||||
sequence_length = 64;
|
||||
} else if (max_seq_len <= 96) {
|
||||
seq_len = 96;
|
||||
sequence_length = 96;
|
||||
} else if (max_seq_len <= 128) {
|
||||
seq_len = 128;
|
||||
sequence_length = 128;
|
||||
} else if (max_seq_len <= 192) {
|
||||
seq_len = (sm == 70) ? 256 : 192;
|
||||
sequence_length = (sm_ == 70) ? 256 : 192;
|
||||
} else if (max_seq_len <= 256) {
|
||||
seq_len = 256;
|
||||
sequence_length = 256;
|
||||
} else if (max_seq_len <= 384) {
|
||||
seq_len = 384;
|
||||
sequence_length = 384;
|
||||
}
|
||||
|
||||
return seq_len;
|
||||
return sequence_length;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool is_flash_attention(const int seq_len) const {
|
||||
ORT_ENFORCE(interface->mHasCausalMask == false);
|
||||
return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention;
|
||||
bool UseFlashAttention(int sequence_length) const {
|
||||
ORT_ENFORCE(interface_->is_causal_ == false);
|
||||
return interface_->enable_flash_attention_ && sequence_length >= kMinSequenceLengthFlashAttention;
|
||||
}
|
||||
|
||||
private:
|
||||
FusedMHARunnerFP16v2* interface;
|
||||
Fused_multihead_attention_params_v2 params;
|
||||
int sm;
|
||||
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
|
||||
const FusedMultiHeadFlashAttentionKernel* flash_attention_kernel;
|
||||
size_t xmmas_m;
|
||||
size_t threads_per_cta;
|
||||
bool use_flash_attention = false;
|
||||
bool has_causal_mask = false;
|
||||
FusedMHARunnerFP16v2* interface_;
|
||||
int sm_;
|
||||
const FusedMultiHeadAttentionXMMAKernelV2* xmma_kernel_;
|
||||
const FusedMultiHeadFlashAttentionKernel* flash_kernel_;
|
||||
};
|
||||
|
||||
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads,
|
||||
const int headSize,
|
||||
const int sm,
|
||||
bool causal_mask,
|
||||
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(int num_heads,
|
||||
int head_size,
|
||||
int sm,
|
||||
bool causal,
|
||||
bool enable_flash_attention,
|
||||
const float scale)
|
||||
: MHARunner(numHeads, headSize, 2, causal_mask, scale),
|
||||
mSm(sm),
|
||||
mEnableFlashAttention(enable_flash_attention),
|
||||
pimpl(new mhaImpl(this)) {
|
||||
float scale)
|
||||
: MHARunner(num_heads, head_size, causal, scale),
|
||||
enable_flash_attention_(enable_flash_attention),
|
||||
impl_(new FmhaImpl(this, sm)) {
|
||||
}
|
||||
|
||||
void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) {
|
||||
MHARunner::setup(seq_len, B);
|
||||
if (mHasCausalMask) {
|
||||
pimpl->setup_causal_masked_fmha(seq_len, B);
|
||||
} else {
|
||||
pimpl->setup(seq_len, B);
|
||||
}
|
||||
}
|
||||
|
||||
bool FusedMHARunnerFP16v2::is_supported(int sm, int head_size, int sequence_length,
|
||||
bool enable_flash_attention, bool causal) {
|
||||
bool FusedMHARunnerFP16v2::IsSupported(int sm, int head_size, int sequence_length,
|
||||
bool enable_flash_attention, bool causal) {
|
||||
if (causal) {
|
||||
if (!(sm == kSM_70 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89)) {
|
||||
return false;
|
||||
|
|
@ -284,34 +280,44 @@ bool FusedMHARunnerFP16v2::is_supported(int sm, int head_size, int sequence_leng
|
|||
return sequence_length <= max_sequence_length;
|
||||
}
|
||||
|
||||
size_t FusedMHARunnerFP16v2::getWorkspaceSize() const {
|
||||
return 0;
|
||||
void FusedMHARunnerFP16v2::Run(int batch_size,
|
||||
int normalized_sequence_length,
|
||||
const void* input,
|
||||
const void* cu_seqlens,
|
||||
void* output,
|
||||
cudaStream_t stream) const {
|
||||
Fused_multihead_attention_params_v2 params;
|
||||
bool use_flash_attention = false;
|
||||
if (is_causal_) {
|
||||
impl_->SetupCausal(params, normalized_sequence_length, batch_size, use_flash_attention);
|
||||
} else {
|
||||
impl_->Setup(params, normalized_sequence_length, batch_size, use_flash_attention);
|
||||
}
|
||||
|
||||
impl_->Run(params, input, cu_seqlens, output, stream, use_flash_attention, is_causal_);
|
||||
}
|
||||
|
||||
void FusedMHARunnerFP16v2::run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) {
|
||||
pimpl->run(input, cu_seqlens, output, stream);
|
||||
bool FusedMHARunnerFP16v2::IsValid(int normalized_sequence_length) const {
|
||||
return impl_->IsValid(normalized_sequence_length);
|
||||
}
|
||||
|
||||
bool FusedMHARunnerFP16v2::isValid(int s) const {
|
||||
return pimpl->isValid(s);
|
||||
int FusedMHARunnerFP16v2::NormalizeSequenceLength(int max_seq_len) const {
|
||||
return impl_->NormalizeSequenceLength(max_seq_len);
|
||||
}
|
||||
|
||||
int FusedMHARunnerFP16v2::getSFromMaxSeqLen(const int max_seq_len) const {
|
||||
return pimpl->getSFromMaxSeqLen(max_seq_len);
|
||||
}
|
||||
|
||||
std::unique_ptr<MHARunner> FusedMHARunnerFP16v2::Create(const int numHeads,
|
||||
const int headSize,
|
||||
const int sm,
|
||||
bool causal_mask,
|
||||
bool enable_flash_attention,
|
||||
const float scale) {
|
||||
std::unique_ptr<MHARunner> FusedMHARunnerFP16v2::Create(int num_heads,
|
||||
int head_size,
|
||||
int sm,
|
||||
bool causal,
|
||||
bool enable_flash_attention,
|
||||
const float scale) {
|
||||
#ifdef _MSC_VER
|
||||
return std::make_unique<FusedMHARunnerFP16v2>(numHeads, headSize, sm, causal_mask, enable_flash_attention, scale);
|
||||
return std::make_unique<FusedMHARunnerFP16v2>(num_heads, head_size, sm, causal, enable_flash_attention, scale);
|
||||
#else
|
||||
// Linux build has error using make_unique: invalid application of ‘sizeof’ to incomplete type ‘onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl
|
||||
// Linux build has error using make_unique: invalid application of ‘sizeof’ to
|
||||
// incomplete type ‘onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::FmhaImpl
|
||||
std::unique_ptr<MHARunner> runner;
|
||||
runner.reset(new FusedMHARunnerFP16v2(numHeads, headSize, sm, causal_mask, enable_flash_attention, scale));
|
||||
runner.reset(new FusedMHARunnerFP16v2(num_heads, head_size, sm, causal, enable_flash_attention, scale));
|
||||
return runner;
|
||||
#endif
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
// Modifications: Update interface and implmentation to be thread-safe
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
|
@ -25,103 +29,70 @@ namespace cuda {
|
|||
|
||||
constexpr int kMinSequenceLengthFlashAttention = 385;
|
||||
|
||||
// Multi-Head Attention runner
|
||||
class MHARunner {
|
||||
public:
|
||||
MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask, const float scale)
|
||||
: mS(0),
|
||||
mB(0),
|
||||
mOmatSize(0),
|
||||
mNumMats(0),
|
||||
mNumHeads(numHeads),
|
||||
mHeadSize(headSize),
|
||||
mWordSize(wordSize),
|
||||
mLdQKV(0),
|
||||
mStrideQKV(0),
|
||||
mLdOut(0),
|
||||
mStrideOut(0),
|
||||
mScale(scale == 0.0f ? 1.f / sqrtf(static_cast<float>(headSize))
|
||||
: scale),
|
||||
mHasCausalMask(causal_mask) {
|
||||
MHARunner(int num_heads, int head_size, bool causal, float scale)
|
||||
: num_heads_(num_heads),
|
||||
head_size_(head_size),
|
||||
scale_(scale == 0.0f ? 1.f / sqrtf(static_cast<float>(head_size)) : scale),
|
||||
is_causal_(causal) {
|
||||
}
|
||||
|
||||
virtual ~MHARunner() = default;
|
||||
|
||||
virtual void setup(const int S, const int B) {
|
||||
ORT_ENFORCE(S > 0);
|
||||
ORT_ENFORCE(B > 0);
|
||||
virtual int NormalizeSequenceLength(int max_seq_len) const = 0;
|
||||
|
||||
mB = B;
|
||||
mS = S;
|
||||
virtual bool IsValid(int normalized_sequence_length) const = 0;
|
||||
|
||||
mLdQKV = 3 * B * mNumHeads * mHeadSize;
|
||||
mStrideQKV = 3 * mHeadSize;
|
||||
|
||||
mLdOut = B * mNumHeads * mHeadSize;
|
||||
mStrideOut = mHeadSize;
|
||||
mOmatSize = S * S;
|
||||
mNumMats = B * mNumHeads;
|
||||
}
|
||||
|
||||
virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0;
|
||||
|
||||
virtual size_t getWorkspaceSize() const = 0;
|
||||
|
||||
virtual bool isValid(int s) const = 0;
|
||||
|
||||
virtual int getSFromMaxSeqLen(const int max_seq_len) const = 0;
|
||||
virtual void Run(int batch_size,
|
||||
int normalized_sequence_length,
|
||||
const void* input,
|
||||
const void* cu_seqlens,
|
||||
void* output,
|
||||
cudaStream_t stream) const = 0;
|
||||
|
||||
protected:
|
||||
int mS;
|
||||
int mB;
|
||||
int mOmatSize;
|
||||
int mNumMats;
|
||||
int mNumHeads;
|
||||
int mHeadSize;
|
||||
int mWordSize;
|
||||
int mLdQKV;
|
||||
int mStrideQKV;
|
||||
int mLdOut;
|
||||
int mStrideOut;
|
||||
|
||||
float mScale;
|
||||
bool mHasCausalMask;
|
||||
int num_heads_;
|
||||
int head_size_;
|
||||
float scale_;
|
||||
bool is_causal_;
|
||||
};
|
||||
|
||||
class FusedMHARunnerFP16v2 : public MHARunner {
|
||||
public:
|
||||
FusedMHARunnerFP16v2(const int numHeads,
|
||||
const int headSize,
|
||||
const int sm,
|
||||
bool causal_mask,
|
||||
FusedMHARunnerFP16v2(int num_heads,
|
||||
int head_size,
|
||||
int sm,
|
||||
bool causal,
|
||||
bool enable_flash_attention,
|
||||
const float scale);
|
||||
~FusedMHARunnerFP16v2() = default; // for pimpl
|
||||
float scale);
|
||||
|
||||
virtual void setup(const int S, const int B) override;
|
||||
~FusedMHARunnerFP16v2() = default; // for impl_
|
||||
|
||||
static bool is_supported(int sm, int head_size, int sequence_length, bool enable_flash_attention, bool causal);
|
||||
static bool IsSupported(int sm, int head_size, int sequence_length, bool enable_flash_attention, bool causal);
|
||||
|
||||
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override;
|
||||
|
||||
size_t getWorkspaceSize() const override;
|
||||
|
||||
bool isValid(int s) const override;
|
||||
|
||||
int getSFromMaxSeqLen(const int max_seq_len) const override;
|
||||
|
||||
static std::unique_ptr<MHARunner> Create(const int numHeads,
|
||||
const int headSize,
|
||||
const int sm,
|
||||
bool causal_mask,
|
||||
static std::unique_ptr<MHARunner> Create(int num_heads,
|
||||
int head_size,
|
||||
int sm,
|
||||
bool causal,
|
||||
bool enable_flash_attention,
|
||||
const float scale);
|
||||
float scale);
|
||||
|
||||
bool IsValid(int normalized_sequence_length) const override;
|
||||
|
||||
int NormalizeSequenceLength(int max_seq_len) const override;
|
||||
|
||||
void Run(int batch_size,
|
||||
int normalized_sequence_length,
|
||||
const void* input,
|
||||
const void* cu_seqlens,
|
||||
void* output,
|
||||
cudaStream_t stream) const override;
|
||||
|
||||
private:
|
||||
int mSm;
|
||||
bool mEnableFlashAttention;
|
||||
class mhaImpl;
|
||||
std::unique_ptr<mhaImpl> pimpl;
|
||||
bool enable_flash_attention_;
|
||||
class FmhaImpl;
|
||||
std::unique_ptr<FmhaImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -156,6 +156,49 @@ class MultiHeadAttentionConfig:
|
|||
)
|
||||
return shapes
|
||||
|
||||
def symbolic_shape_dict(self, input_format=None):
|
||||
input_format = input_format or self.input_format
|
||||
if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
|
||||
# cross attention does not have past state
|
||||
return {
|
||||
"query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
"key": ("batch_size", self.num_heads, "sequence_length", self.head_size),
|
||||
"value": ("batch_size", self.num_heads, "sequence_length", self.head_size),
|
||||
"output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
}
|
||||
|
||||
if self.use_kv_cache:
|
||||
shapes = {
|
||||
"past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
|
||||
"past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
|
||||
"output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
"present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
|
||||
"present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
|
||||
}
|
||||
else:
|
||||
shapes = {
|
||||
"output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
}
|
||||
|
||||
if input_format == InputFormats.QKV_BSN3H:
|
||||
shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)})
|
||||
elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
|
||||
shapes.update(
|
||||
{
|
||||
"query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
"key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size),
|
||||
}
|
||||
)
|
||||
else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH
|
||||
shapes.update(
|
||||
{
|
||||
"query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
"key": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
"value": ("batch_size", "sequence_length", self.num_heads * self.head_size),
|
||||
}
|
||||
)
|
||||
return shapes
|
||||
|
||||
def random_inputs(self, seed: int = 123):
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
|
|
@ -215,7 +258,7 @@ class MultiHeadAttentionConfig:
|
|||
|
||||
def get_input_output_names(self):
|
||||
if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
|
||||
return ["query", "key"], ["output"]
|
||||
return ["query", "key", "value"], ["output"]
|
||||
|
||||
if self.input_format == InputFormats.QKV_BSN3H:
|
||||
inputs, outputs = ["query"], ["output"]
|
||||
|
|
@ -235,7 +278,7 @@ def fill_optional_mha_inputs(input_names):
|
|||
return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:]
|
||||
|
||||
|
||||
def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig):
|
||||
def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False):
|
||||
input_names, output_names = config.get_input_output_names()
|
||||
|
||||
float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT
|
||||
|
|
@ -252,7 +295,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig):
|
|||
),
|
||||
]
|
||||
|
||||
shape_dict = config.shape_dict()
|
||||
shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict()
|
||||
inputs = [
|
||||
helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name]))
|
||||
for input_name in input_names
|
||||
|
|
|
|||
|
|
@ -7,17 +7,39 @@
|
|||
Test MultiHeadAttention operator for CUDA and CPU.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import itertools
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention
|
||||
from benchmark_mha import (
|
||||
InputFormats,
|
||||
MultiHeadAttentionConfig,
|
||||
OrtMultiHeadAttention,
|
||||
create_multi_head_attention_onnx_model,
|
||||
)
|
||||
from einops import rearrange
|
||||
from parameterized import parameterized
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
|
||||
class SdpaKernel(IntEnum):
|
||||
"""Bit flags for sdpa_kernel CUDA provider option"""
|
||||
|
||||
DEFAULT = 0
|
||||
FLASH_ATTENTION = 1
|
||||
EFFICIENT_ATTENTION = 2
|
||||
TRT_FUSED_ATTENTION = 4
|
||||
CUDNN_FLASH_ATTENTION = 8
|
||||
MATH = 16
|
||||
TRT_FLASH_ATTENTION = 32
|
||||
TRT_CROSS_ATTENTION = 64
|
||||
TRT_CAUSAL_ATTENTION = 128
|
||||
|
||||
|
||||
def attention_reference(
|
||||
|
|
@ -105,9 +127,16 @@ def mha_with_past_reference(
|
|||
|
||||
def get_provider_support_info(provider: str, use_kv_cache: bool):
|
||||
if provider == "CUDAExecutionProvider":
|
||||
formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]
|
||||
if not use_kv_cache:
|
||||
formats.append(InputFormats.Q_K_V_BSNH_BSNH_BSNH)
|
||||
formats = [
|
||||
InputFormats.Q_K_V_BSNH_BSNH_BSNH,
|
||||
InputFormats.Q_KV_BSNH_BSN2H,
|
||||
InputFormats.QKV_BSN3H,
|
||||
InputFormats.Q_K_V_BSNH_BNSH_BNSH,
|
||||
]
|
||||
else:
|
||||
formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH]
|
||||
|
||||
device_id = torch.cuda.current_device()
|
||||
device = torch.device("cuda", device_id)
|
||||
dtype = torch.float16
|
||||
|
|
@ -121,15 +150,16 @@ def get_provider_support_info(provider: str, use_kv_cache: bool):
|
|||
return device, dtype, formats
|
||||
|
||||
|
||||
def has_cuda_support():
|
||||
def get_compute_capability():
|
||||
if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
return major >= 6
|
||||
return False
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
sm = major * 10 + minor
|
||||
return sm
|
||||
return 0
|
||||
|
||||
|
||||
def no_kv_cache_test_cases(provider: str, comprehensive: bool):
|
||||
if provider == "CUDAExecutionProvider" and not has_cuda_support():
|
||||
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
|
||||
return
|
||||
yield
|
||||
|
||||
|
|
@ -192,7 +222,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
|
|||
|
||||
|
||||
def kv_cache_test_cases(provider: str, comprehensive: bool):
|
||||
if provider == "CUDAExecutionProvider" and not has_cuda_support():
|
||||
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
|
||||
return
|
||||
yield
|
||||
|
||||
|
|
@ -262,6 +292,92 @@ def mha_test_cases(provider: str, comprehensive: bool):
|
|||
)
|
||||
|
||||
|
||||
def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
|
||||
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
|
||||
return
|
||||
yield
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 400] if comprehensive else [1, 64, 128, 256]
|
||||
heads = [4]
|
||||
head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64]
|
||||
|
||||
device, dtype, formats = get_provider_support_info(provider, False)
|
||||
|
||||
for format in formats:
|
||||
for causal in [True, False]:
|
||||
for num_heads in heads:
|
||||
for head_size in head_sizes:
|
||||
configs = [] # list of configurations to run in parallel
|
||||
for batch_size in batch_sizes:
|
||||
for sequence_length in sequence_lengths:
|
||||
config = MultiHeadAttentionConfig(
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
causal=causal,
|
||||
past_sequence_length=0,
|
||||
kv_sequence_length=sequence_length,
|
||||
max_cache_sequence_length=None,
|
||||
provider=provider,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
use_kv_cache=False,
|
||||
share_past_present_buffer=False,
|
||||
input_format=format,
|
||||
)
|
||||
configs.append(config)
|
||||
yield configs
|
||||
|
||||
|
||||
def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
|
||||
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
|
||||
return
|
||||
yield
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
sequence_lengths = [1, 32, 127, 128, 383, 384, 400] if comprehensive else [1, 32, 127, 128]
|
||||
heads = [4]
|
||||
head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64]
|
||||
|
||||
sequence_length = 1
|
||||
device, dtype, formats = get_provider_support_info(provider, True)
|
||||
|
||||
for format in formats:
|
||||
for causal in [True, False]:
|
||||
for num_heads in heads:
|
||||
for head_size in head_sizes:
|
||||
configs = []
|
||||
for batch_size in batch_sizes:
|
||||
for past_sequence_length in sequence_lengths:
|
||||
config = MultiHeadAttentionConfig(
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
causal=causal,
|
||||
past_sequence_length=past_sequence_length,
|
||||
kv_sequence_length=sequence_length,
|
||||
max_cache_sequence_length=None,
|
||||
provider=provider,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
use_kv_cache=True,
|
||||
share_past_present_buffer=False,
|
||||
input_format=format,
|
||||
)
|
||||
configs.append(config)
|
||||
yield configs
|
||||
|
||||
|
||||
def multi_thread_test_cases(provider: str, comprehensive: bool):
|
||||
return itertools.chain(
|
||||
no_kv_cache_multi_thread_test_cases(provider, comprehensive),
|
||||
kv_cache_multi_thread_test_cases(provider, comprehensive),
|
||||
)
|
||||
|
||||
|
||||
def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None):
|
||||
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
||||
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
||||
|
|
@ -346,20 +462,189 @@ def parity_check_mha(
|
|||
)
|
||||
|
||||
|
||||
def parity_check_mha_multi_threading(
|
||||
test_inputs: List[Dict],
|
||||
rtol: float = 1e-3,
|
||||
atol: float = 1e-3,
|
||||
sdpa_kernel: int = SdpaKernel.DEFAULT,
|
||||
max_threads: int = 5,
|
||||
verbose: bool = False,
|
||||
):
|
||||
# Use the first config to create a session, which is shared by all configs to run in parallel.
|
||||
config = test_inputs[0]["config"]
|
||||
# For now, MHA CUDA kernel does not support causal so skip such test cases.
|
||||
if config.causal and config.provider == "CUDAExecutionProvider":
|
||||
return None
|
||||
# Some kernel does not support certain input format.
|
||||
if sdpa_kernel not in [
|
||||
SdpaKernel.DEFAULT,
|
||||
SdpaKernel.FLASH_ATTENTION,
|
||||
SdpaKernel.EFFICIENT_ATTENTION,
|
||||
] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]:
|
||||
return None
|
||||
if verbose:
|
||||
print(f"create a shared session with {vars(config)}")
|
||||
onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True)
|
||||
if config.provider == "CUDAExecutionProvider":
|
||||
provider_options = {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": int(sdpa_kernel)}
|
||||
providers = [(config.provider, provider_options), "CPUExecutionProvider"]
|
||||
else:
|
||||
providers = ["CPUExecutionProvider"]
|
||||
ort_session = InferenceSession(onnx_model_str, providers=providers)
|
||||
|
||||
def convert_to_ort_inputs(feed_dict):
|
||||
ort_inputs = {}
|
||||
|
||||
for k, v in feed_dict.items():
|
||||
if isinstance(v, numpy.ndarray):
|
||||
ort_inputs[k] = v
|
||||
else:
|
||||
ort_inputs[k] = v.detach().cpu().numpy()
|
||||
return ort_inputs
|
||||
|
||||
def check_parity_with_config(i: int):
|
||||
config = test_inputs[i]["config"]
|
||||
if verbose:
|
||||
print(f"Thread {i} with {vars(config)}")
|
||||
|
||||
ort_inputs = test_inputs[i]["ort_inputs"]
|
||||
|
||||
if verbose:
|
||||
print(f"Thread {i} ort inputs: {ort_inputs}")
|
||||
ort_outputs = ort_session.run(None, convert_to_ort_inputs(ort_inputs))
|
||||
out = numpy.reshape(
|
||||
ort_outputs[0], (config.batch_size, config.sequence_length, config.num_heads, config.head_size)
|
||||
)
|
||||
|
||||
# Create reference inputs
|
||||
config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH
|
||||
ref_inputs = test_inputs[i]["ref_inputs"]
|
||||
if verbose:
|
||||
print(f"Thread {i} ref inputs: {ref_inputs}")
|
||||
q = (
|
||||
ref_inputs["query"]
|
||||
.reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size))
|
||||
.transpose(1, 2)
|
||||
)
|
||||
k = (
|
||||
ref_inputs["key"]
|
||||
.reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size))
|
||||
.transpose(1, 2)
|
||||
)
|
||||
v = (
|
||||
ref_inputs["value"]
|
||||
.reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size))
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
mask = None
|
||||
if config.causal:
|
||||
mask = causal_mask(config.sequence_length, config.total_sequence_length, device=config.device)
|
||||
|
||||
k_cache = None
|
||||
v_cache = None
|
||||
if config.use_kv_cache:
|
||||
past_k = ref_inputs["past_key"]
|
||||
past_v = ref_inputs["past_value"]
|
||||
out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask)
|
||||
else:
|
||||
out_ref = attention_reference(config.head_size, q, k, v, mask=mask)
|
||||
|
||||
try:
|
||||
numpy.testing.assert_allclose(
|
||||
out,
|
||||
out_ref.detach().cpu().numpy(),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=True,
|
||||
err_msg=f"output not close: {config=}",
|
||||
)
|
||||
|
||||
if config.use_kv_cache:
|
||||
present_key = ort_outputs[1]
|
||||
numpy.testing.assert_allclose(
|
||||
k_cache.detach().cpu().numpy(),
|
||||
present_key,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=True,
|
||||
err_msg=f"present_key not close: {config=}",
|
||||
)
|
||||
|
||||
present_value = ort_outputs[2]
|
||||
numpy.testing.assert_allclose(
|
||||
v_cache.detach().cpu().numpy(),
|
||||
present_value,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
equal_nan=True,
|
||||
err_msg=f"present_value not close: {config=}",
|
||||
)
|
||||
except AssertionError as e:
|
||||
print(f"Failed with {vars(config)}: {e}")
|
||||
return e
|
||||
|
||||
if verbose:
|
||||
print(f"Passed: {vars(config)}")
|
||||
return None
|
||||
|
||||
num_threads = min(max_threads, len(test_inputs))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
future_tasks = [executor.submit(check_parity_with_config, i) for i in range(num_threads)]
|
||||
for future in concurrent.futures.as_completed(future_tasks):
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine.
|
||||
comprehensive_mode = False
|
||||
|
||||
|
||||
class TestMultiHeadAttention(unittest.TestCase):
|
||||
# TODO: enable tests on CUDAExecutionProvider after fixing the issue.
|
||||
# @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True)
|
||||
# def test_mha_cuda(self, config):
|
||||
# parity_check_mha(config)
|
||||
@parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True)
|
||||
def test_mha_cuda(self, config):
|
||||
parity_check_mha(config)
|
||||
|
||||
@parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True)
|
||||
def test_mha_cpu(self, config):
|
||||
parity_check_mha(config)
|
||||
|
||||
def run_mha_cuda_multi_threading(self, spda_kernel):
|
||||
for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode):
|
||||
test_inputs = []
|
||||
for config in configs:
|
||||
ort_inputs = config.random_inputs()
|
||||
|
||||
# Create reference inputs
|
||||
old_format = config.input_format
|
||||
config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH
|
||||
ref_inputs = config.random_inputs()
|
||||
config.input_format = old_format
|
||||
test_inputs.append({"config": config, "ort_inputs": ort_inputs, "ref_inputs": ref_inputs})
|
||||
|
||||
exception = parity_check_mha_multi_threading(test_inputs, sdpa_kernel=spda_kernel, max_threads=len(configs))
|
||||
assert exception is None, f"{spda_kernel=}, {vars(configs[0])}, {exception}"
|
||||
|
||||
def test_mha_cuda_multi_threading(self):
|
||||
self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT)
|
||||
|
||||
def test_mha_cuda_multi_threading_efficient(self):
|
||||
self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION)
|
||||
|
||||
def test_mha_cuda_multi_threading_trt(self):
|
||||
sm = get_compute_capability()
|
||||
if sm in [75, 80, 86, 89]:
|
||||
self.run_mha_cuda_multi_threading(
|
||||
SdpaKernel.TRT_FUSED_ATTENTION
|
||||
| SdpaKernel.TRT_FLASH_ATTENTION
|
||||
| SdpaKernel.TRT_CROSS_ATTENTION
|
||||
| SdpaKernel.TRT_CAUSAL_ATTENTION
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
Loading…
Reference in a new issue