[CUDA] FusedMHARunnerFP16v2 thread-safe (#21420)

### Description
- [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe. 
- [x] Add multi-threading tests

Previously, the kernel parameters params is stored as a member of mha
runner, which means that different threads might change the params at
the same time and impacts the other threads.

For example, if batch_size and seq_len was changed by another thread to
larger values in setup(...), buffer overrun might happen in run(...)
because a kernel could read/write memory out of range of allocated
buffers.

In new implementation, I change the api and remove mutable member
variables to make it thread safe. Below is summary of change:

Before:
```
class FusedMHARunnerFP16v2::mhaImpl {
   void setup(int seq_len, int batch_size) {
      // change scalar params
   }

   void run(input, output) {
      // change params for input and output pointers
      // launch kernel using params
   }

   Fused_multihead_attention_params_v2 params; // mutable, not thread-safe
}
```

After:
```
class FusedMHARunnerFP16v2::FmhaImpl {
   void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
      // change params
   }

   void run(params, input, output) {
      // change params with input and output pointers
      // launch kernel using params
   }
}
```

### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/18854
https://github.com/microsoft/onnxruntime/issues/21413
This commit is contained in:
Tianlei Wu 2024-07-22 10:41:08 -07:00 committed by GitHub
parent 11bf309736
commit a6c5e2cd20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 537 additions and 237 deletions

View file

@ -149,8 +149,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == relative_position_bias &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, true);
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, true);
if (use_causal_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
@ -171,8 +171,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == present &&
nullptr == relative_position_bias &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
if (use_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
@ -184,8 +184,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
if (fused_fp16_runner_->isValid(S)) {
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
fused_runner = fused_fp16_runner_.get();
}
}

View file

@ -245,12 +245,10 @@ Status FusedTrtSelfAttention(
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(data.fused_runner);
const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length);
// B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed.
const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
fused_fp16_runner->setup(S, B);
const int b = (nullptr == data.mask_index ? batch_size : 2 * batch_size);
if (!causal) {
assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H);
@ -261,12 +259,12 @@ Status FusedTrtSelfAttention(
packed_qkv = data.query;
}
fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream);
DUMP_TENSOR("fused output", data.output,
batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
} else {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream);
DUMP_TENSOR("fused causal output", data.output,
batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
}

View file

@ -193,8 +193,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
if (use_fused_runner) {
// Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
if (nullptr == fused_fp16_runner_.get()) {
@ -206,8 +206,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
if (fused_fp16_runner_->isValid(S)) {
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
fused_runner = fused_fp16_runner_.get();
}
}

View file

@ -55,11 +55,11 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro
// Check whether we can use fused kernel
int sm = device_prop.major * 10 + device_prop.minor;
bool is_fMHA_supported = FusedMHARunnerFP16v2::is_supported(sm,
parameters.head_size,
parameters.sequence_length,
enable_trt_flash_attention_,
false /*causal*/);
bool is_fMHA_supported = FusedMHARunnerFP16v2::IsSupported(sm,
parameters.head_size,
parameters.sequence_length,
enable_trt_flash_attention_,
false /*causal*/);
if (!is_fMHA_supported) {
return fused_runner;
@ -72,8 +72,8 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
const int S = fused_fp16_runner_->getSFromMaxSeqLen(parameters.sequence_length);
if (fused_fp16_runner_->isValid(S)) {
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(parameters.sequence_length);
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
fused_runner = fused_fp16_runner_.get();
}

View file

@ -459,10 +459,9 @@ Status FusedScaledDotProductAttention(
parameters.token_count, stream);
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
fused_fp16_runner->setup(S, batch_size);
fused_fp16_runner->run(data.workspace, data.cumulative_sequence_length, data.output, stream);
const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
fused_fp16_runner->Run(batch_size, normalized_seq_len,
data.workspace, data.cumulative_sequence_length, data.output, stream);
return Status::OK();
}

View file

@ -575,10 +575,8 @@ Status FusedAttentionTrt(
}
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
fused_fp16_runner->setup(S, batch_size);
fused_fp16_runner->run(qkv, data.cumulative_sequence_length, data.output, stream);
const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
fused_fp16_runner->Run(batch_size, normalized_seq_len, qkv, data.cumulative_sequence_length, data.output, stream);
return Status::OK();
}

View file

@ -14,6 +14,10 @@
* limitations under the License.
*/
// Modifications: Update interface and implmentation to be thread-safe
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/flash_attention/fmha_flash_attention.h"
@ -34,28 +38,28 @@ void set_alpha_fp16(uint32_t& alpha, float norm) {
alpha = temp.u32;
}
class FusedMHARunnerFP16v2::mhaImpl {
class FusedMHARunnerFP16v2::FmhaImpl {
public:
mhaImpl(FusedMHARunnerFP16v2* interface)
: interface(interface),
sm(interface->mSm),
xmmaKernel(getXMMAKernelsV2(DATA_TYPE_FP16, sm)) {
FmhaImpl(FusedMHARunnerFP16v2* interface, int sm)
: interface_(interface),
sm_(sm),
xmma_kernel_(getXMMAKernelsV2(DATA_TYPE_FP16, sm)) {
ORT_ENFORCE((sm == kSM_70 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89),
"Unsupported architecture");
flash_attention_kernel = nullptr;
if (interface->mEnableFlashAttention) {
flash_attention_kernel = get_flash_attention_kernels(DATA_TYPE_FP16, sm);
flash_kernel_ = nullptr;
if (interface_->enable_flash_attention_) {
flash_kernel_ = get_flash_attention_kernels(DATA_TYPE_FP16, sm);
}
params.clear();
}
~mhaImpl() {}
~FmhaImpl() {}
void setup(const int seq_len, const int B) {
// For bert and vit, use flash attention when sequence length is larger than the threshold.
use_flash_attention = is_flash_attention(seq_len);
void Setup(Fused_multihead_attention_params_v2& params,
int sequence_length, // normalized sequence length
int batch_size,
bool& use_flash_attention) const {
use_flash_attention = UseFlashAttention(sequence_length);
params.force_unroll = use_flash_attention;
@ -67,27 +71,27 @@ class FusedMHARunnerFP16v2::mhaImpl {
warps_m = 4;
warps_n = 1;
} else {
if (sm == 70) {
if (seq_len == 64 || seq_len == 96) {
if (sm_ == 70) {
if (sequence_length == 64 || sequence_length == 96) {
warps_m = 2;
warps_n = 2;
} else if (seq_len == 128) {
} else if (sequence_length == 128) {
warps_m = 1;
warps_n = 4;
} else if (seq_len == 256 || seq_len == 384) {
} else if (sequence_length == 256 || sequence_length == 384) {
warps_m = 1;
warps_n = 8;
} else {
ORT_ENFORCE(false, "Unsupported sequence length");
}
} else {
if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) {
if (sequence_length == 32 || sequence_length == 64 || sequence_length == 96 || sequence_length == 128) {
warps_m = 2;
warps_n = 2;
} else if (seq_len == 192 || seq_len == 256) {
} else if (sequence_length == 192 || sequence_length == 256) {
warps_m = 1;
warps_n = 4;
} else if (seq_len == 384) {
} else if (sequence_length == 384) {
warps_m = 1;
warps_n = 8;
} else {
@ -97,11 +101,11 @@ class FusedMHARunnerFP16v2::mhaImpl {
}
// The number of threads per CTA.
threads_per_cta = warps_m * warps_n * warps_k * 32;
size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m);
size_t xmmas_m = (sequence_length + 16 * warps_m - 1) / (16 * warps_m);
const float scale_bmm1 = interface->mScale;
const float scale_bmm1 = interface_->scale_;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@ -109,20 +113,21 @@ class FusedMHARunnerFP16v2::mhaImpl {
set_alpha_fp16(params.scale_softmax, scale_softmax);
set_alpha_fp16(params.scale_bmm2, scale_bmm2);
params.b = B;
params.h = interface->mNumHeads;
params.s = seq_len;
params.d = interface->mHeadSize;
params.b = batch_size;
params.h = interface_->num_heads_;
params.s = sequence_length;
params.d = interface_->head_size_;
params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
params.qkv_stride_in_bytes = 3 * interface_->num_heads_ * interface_->head_size_ * sizeof(half);
params.packed_mask_stride_in_bytes = xmmas_m * threads_per_cta * sizeof(uint32_t);
params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half);
has_causal_mask = false;
params.o_stride_in_bytes = interface_->num_heads_ * interface_->head_size_ * sizeof(half);
}
void setup_causal_masked_fmha(const int seq_len, const int B) {
const float scale_bmm1 = interface->mScale;
void SetupCausal(Fused_multihead_attention_params_v2& params,
int sequence_length, // normalized sequence length
int batch_size,
bool& use_flash_attention) const {
const float scale_bmm1 = interface_->scale_;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@ -130,16 +135,17 @@ class FusedMHARunnerFP16v2::mhaImpl {
set_alpha_fp16(params.scale_softmax, scale_softmax);
set_alpha_fp16(params.scale_bmm2, scale_bmm2);
params.b = B;
params.h = interface->mNumHeads;
params.s = seq_len;
params.d = interface->mHeadSize;
params.b = batch_size;
params.h = interface_->num_heads_;
params.s = sequence_length;
params.d = interface_->head_size_;
params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
params.o_stride_in_bytes = interface->mNumHeads * interface->mHeadSize * sizeof(half);
params.qkv_stride_in_bytes = 3 * interface_->num_heads_ * interface_->head_size_ * sizeof(half);
params.o_stride_in_bytes = interface_->num_heads_ * interface_->head_size_ * sizeof(half);
// fallback to original fmha_v2 when head_size <= 64 and seq_len <- 128
use_flash_attention = interface->mEnableFlashAttention;
use_flash_attention = interface_->enable_flash_attention_;
// fallback to original fmha_v2 when head_size <= 64 and sequence_length <= 128
if (params.d <= 64 && params.s <= 128) {
use_flash_attention = false;
// get max sequence length
@ -152,97 +158,87 @@ class FusedMHARunnerFP16v2::mhaImpl {
// set flags
params.force_unroll = use_flash_attention;
has_causal_mask = true;
}
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) {
void Run(Fused_multihead_attention_params_v2& params,
const void* input,
const void* cu_seqlens,
void* output,
cudaStream_t stream,
bool use_flash_attention,
bool has_causal_mask) const {
params.qkv_ptr = const_cast<void*>(input);
params.o_ptr = output;
params.cu_seqlens = static_cast<int*>(const_cast<void*>(cu_seqlens));
if (use_flash_attention && flash_attention_kernel != nullptr && !has_causal_mask) {
flash_attention_kernel->run(params, stream);
if (use_flash_attention && flash_kernel_ != nullptr && !has_causal_mask) {
flash_kernel_->run(params, stream);
} else {
xmmaKernel->run(params, stream, use_flash_attention, has_causal_mask);
xmma_kernel_->run(params, stream, use_flash_attention, has_causal_mask);
}
CUDA_CALL_THROW(cudaPeekAtLastError());
}
bool isValid(int s) const {
if (is_flash_attention(s)) {
return (flash_attention_kernel != nullptr) && flash_attention_kernel->isValid(s);
bool IsValid(int sequence_length) const {
if (UseFlashAttention(sequence_length)) {
return (flash_kernel_ != nullptr) && flash_kernel_->isValid(sequence_length);
}
return xmmaKernel->isValid(s);
return xmma_kernel_->isValid(sequence_length);
}
int getSFromMaxSeqLen(const int max_seq_len) const {
if (is_flash_attention(max_seq_len)) {
int NormalizeSequenceLength(int max_seq_len) const {
if (UseFlashAttention(max_seq_len)) {
return max_seq_len;
}
int seq_len = max_seq_len;
int sequence_length = max_seq_len;
if (max_seq_len <= 32) {
seq_len = (sm == 70) ? 64 : 32;
sequence_length = (sm_ == 70) ? 64 : 32;
} else if (max_seq_len <= 64) {
seq_len = 64;
sequence_length = 64;
} else if (max_seq_len <= 96) {
seq_len = 96;
sequence_length = 96;
} else if (max_seq_len <= 128) {
seq_len = 128;
sequence_length = 128;
} else if (max_seq_len <= 192) {
seq_len = (sm == 70) ? 256 : 192;
sequence_length = (sm_ == 70) ? 256 : 192;
} else if (max_seq_len <= 256) {
seq_len = 256;
sequence_length = 256;
} else if (max_seq_len <= 384) {
seq_len = 384;
sequence_length = 384;
}
return seq_len;
return sequence_length;
}
protected:
bool is_flash_attention(const int seq_len) const {
ORT_ENFORCE(interface->mHasCausalMask == false);
return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention;
bool UseFlashAttention(int sequence_length) const {
ORT_ENFORCE(interface_->is_causal_ == false);
return interface_->enable_flash_attention_ && sequence_length >= kMinSequenceLengthFlashAttention;
}
private:
FusedMHARunnerFP16v2* interface;
Fused_multihead_attention_params_v2 params;
int sm;
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
const FusedMultiHeadFlashAttentionKernel* flash_attention_kernel;
size_t xmmas_m;
size_t threads_per_cta;
bool use_flash_attention = false;
bool has_causal_mask = false;
FusedMHARunnerFP16v2* interface_;
int sm_;
const FusedMultiHeadAttentionXMMAKernelV2* xmma_kernel_;
const FusedMultiHeadFlashAttentionKernel* flash_kernel_;
};
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads,
const int headSize,
const int sm,
bool causal_mask,
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(int num_heads,
int head_size,
int sm,
bool causal,
bool enable_flash_attention,
const float scale)
: MHARunner(numHeads, headSize, 2, causal_mask, scale),
mSm(sm),
mEnableFlashAttention(enable_flash_attention),
pimpl(new mhaImpl(this)) {
float scale)
: MHARunner(num_heads, head_size, causal, scale),
enable_flash_attention_(enable_flash_attention),
impl_(new FmhaImpl(this, sm)) {
}
void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) {
MHARunner::setup(seq_len, B);
if (mHasCausalMask) {
pimpl->setup_causal_masked_fmha(seq_len, B);
} else {
pimpl->setup(seq_len, B);
}
}
bool FusedMHARunnerFP16v2::is_supported(int sm, int head_size, int sequence_length,
bool enable_flash_attention, bool causal) {
bool FusedMHARunnerFP16v2::IsSupported(int sm, int head_size, int sequence_length,
bool enable_flash_attention, bool causal) {
if (causal) {
if (!(sm == kSM_70 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89)) {
return false;
@ -284,34 +280,44 @@ bool FusedMHARunnerFP16v2::is_supported(int sm, int head_size, int sequence_leng
return sequence_length <= max_sequence_length;
}
size_t FusedMHARunnerFP16v2::getWorkspaceSize() const {
return 0;
void FusedMHARunnerFP16v2::Run(int batch_size,
int normalized_sequence_length,
const void* input,
const void* cu_seqlens,
void* output,
cudaStream_t stream) const {
Fused_multihead_attention_params_v2 params;
bool use_flash_attention = false;
if (is_causal_) {
impl_->SetupCausal(params, normalized_sequence_length, batch_size, use_flash_attention);
} else {
impl_->Setup(params, normalized_sequence_length, batch_size, use_flash_attention);
}
impl_->Run(params, input, cu_seqlens, output, stream, use_flash_attention, is_causal_);
}
void FusedMHARunnerFP16v2::run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) {
pimpl->run(input, cu_seqlens, output, stream);
bool FusedMHARunnerFP16v2::IsValid(int normalized_sequence_length) const {
return impl_->IsValid(normalized_sequence_length);
}
bool FusedMHARunnerFP16v2::isValid(int s) const {
return pimpl->isValid(s);
int FusedMHARunnerFP16v2::NormalizeSequenceLength(int max_seq_len) const {
return impl_->NormalizeSequenceLength(max_seq_len);
}
int FusedMHARunnerFP16v2::getSFromMaxSeqLen(const int max_seq_len) const {
return pimpl->getSFromMaxSeqLen(max_seq_len);
}
std::unique_ptr<MHARunner> FusedMHARunnerFP16v2::Create(const int numHeads,
const int headSize,
const int sm,
bool causal_mask,
bool enable_flash_attention,
const float scale) {
std::unique_ptr<MHARunner> FusedMHARunnerFP16v2::Create(int num_heads,
int head_size,
int sm,
bool causal,
bool enable_flash_attention,
const float scale) {
#ifdef _MSC_VER
return std::make_unique<FusedMHARunnerFP16v2>(numHeads, headSize, sm, causal_mask, enable_flash_attention, scale);
return std::make_unique<FusedMHARunnerFP16v2>(num_heads, head_size, sm, causal, enable_flash_attention, scale);
#else
// Linux build has error using make_unique: invalid application of sizeof to incomplete type onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl
// Linux build has error using make_unique: invalid application of sizeof to
// incomplete type onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::FmhaImpl
std::unique_ptr<MHARunner> runner;
runner.reset(new FusedMHARunnerFP16v2(numHeads, headSize, sm, causal_mask, enable_flash_attention, scale));
runner.reset(new FusedMHARunnerFP16v2(num_heads, head_size, sm, causal, enable_flash_attention, scale));
return runner;
#endif
}

View file

@ -14,6 +14,10 @@
* limitations under the License.
*/
// Modifications: Update interface and implmentation to be thread-safe
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
@ -25,103 +29,70 @@ namespace cuda {
constexpr int kMinSequenceLengthFlashAttention = 385;
// Multi-Head Attention runner
class MHARunner {
public:
MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask, const float scale)
: mS(0),
mB(0),
mOmatSize(0),
mNumMats(0),
mNumHeads(numHeads),
mHeadSize(headSize),
mWordSize(wordSize),
mLdQKV(0),
mStrideQKV(0),
mLdOut(0),
mStrideOut(0),
mScale(scale == 0.0f ? 1.f / sqrtf(static_cast<float>(headSize))
: scale),
mHasCausalMask(causal_mask) {
MHARunner(int num_heads, int head_size, bool causal, float scale)
: num_heads_(num_heads),
head_size_(head_size),
scale_(scale == 0.0f ? 1.f / sqrtf(static_cast<float>(head_size)) : scale),
is_causal_(causal) {
}
virtual ~MHARunner() = default;
virtual void setup(const int S, const int B) {
ORT_ENFORCE(S > 0);
ORT_ENFORCE(B > 0);
virtual int NormalizeSequenceLength(int max_seq_len) const = 0;
mB = B;
mS = S;
virtual bool IsValid(int normalized_sequence_length) const = 0;
mLdQKV = 3 * B * mNumHeads * mHeadSize;
mStrideQKV = 3 * mHeadSize;
mLdOut = B * mNumHeads * mHeadSize;
mStrideOut = mHeadSize;
mOmatSize = S * S;
mNumMats = B * mNumHeads;
}
virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0;
virtual size_t getWorkspaceSize() const = 0;
virtual bool isValid(int s) const = 0;
virtual int getSFromMaxSeqLen(const int max_seq_len) const = 0;
virtual void Run(int batch_size,
int normalized_sequence_length,
const void* input,
const void* cu_seqlens,
void* output,
cudaStream_t stream) const = 0;
protected:
int mS;
int mB;
int mOmatSize;
int mNumMats;
int mNumHeads;
int mHeadSize;
int mWordSize;
int mLdQKV;
int mStrideQKV;
int mLdOut;
int mStrideOut;
float mScale;
bool mHasCausalMask;
int num_heads_;
int head_size_;
float scale_;
bool is_causal_;
};
class FusedMHARunnerFP16v2 : public MHARunner {
public:
FusedMHARunnerFP16v2(const int numHeads,
const int headSize,
const int sm,
bool causal_mask,
FusedMHARunnerFP16v2(int num_heads,
int head_size,
int sm,
bool causal,
bool enable_flash_attention,
const float scale);
~FusedMHARunnerFP16v2() = default; // for pimpl
float scale);
virtual void setup(const int S, const int B) override;
~FusedMHARunnerFP16v2() = default; // for impl_
static bool is_supported(int sm, int head_size, int sequence_length, bool enable_flash_attention, bool causal);
static bool IsSupported(int sm, int head_size, int sequence_length, bool enable_flash_attention, bool causal);
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override;
size_t getWorkspaceSize() const override;
bool isValid(int s) const override;
int getSFromMaxSeqLen(const int max_seq_len) const override;
static std::unique_ptr<MHARunner> Create(const int numHeads,
const int headSize,
const int sm,
bool causal_mask,
static std::unique_ptr<MHARunner> Create(int num_heads,
int head_size,
int sm,
bool causal,
bool enable_flash_attention,
const float scale);
float scale);
bool IsValid(int normalized_sequence_length) const override;
int NormalizeSequenceLength(int max_seq_len) const override;
void Run(int batch_size,
int normalized_sequence_length,
const void* input,
const void* cu_seqlens,
void* output,
cudaStream_t stream) const override;
private:
int mSm;
bool mEnableFlashAttention;
class mhaImpl;
std::unique_ptr<mhaImpl> pimpl;
bool enable_flash_attention_;
class FmhaImpl;
std::unique_ptr<FmhaImpl> impl_;
};
} // namespace cuda

View file

@ -156,6 +156,49 @@ class MultiHeadAttentionConfig:
)
return shapes
def symbolic_shape_dict(self, input_format=None):
input_format = input_format or self.input_format
if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
# cross attention does not have past state
return {
"query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
"key": ("batch_size", self.num_heads, "sequence_length", self.head_size),
"value": ("batch_size", self.num_heads, "sequence_length", self.head_size),
"output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
}
if self.use_kv_cache:
shapes = {
"past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
"past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size),
"output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
"present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
"present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size),
}
else:
shapes = {
"output": ("batch_size", "sequence_length", self.num_heads * self.head_size),
}
if input_format == InputFormats.QKV_BSN3H:
shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)})
elif input_format == InputFormats.Q_KV_BSNH_BSN2H:
shapes.update(
{
"query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
"key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size),
}
)
else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH
shapes.update(
{
"query": ("batch_size", "sequence_length", self.num_heads * self.head_size),
"key": ("batch_size", "sequence_length", self.num_heads * self.head_size),
"value": ("batch_size", "sequence_length", self.num_heads * self.head_size),
}
)
return shapes
def random_inputs(self, seed: int = 123):
device = self.device
dtype = self.dtype
@ -215,7 +258,7 @@ class MultiHeadAttentionConfig:
def get_input_output_names(self):
if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
return ["query", "key"], ["output"]
return ["query", "key", "value"], ["output"]
if self.input_format == InputFormats.QKV_BSN3H:
inputs, outputs = ["query"], ["output"]
@ -235,7 +278,7 @@ def fill_optional_mha_inputs(input_names):
return input_names[:-2] + [""] * (len(inputs) - len(input_names)) + input_names[-2:]
def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig):
def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use_symbolic_shape=False):
input_names, output_names = config.get_input_output_names()
float_type = TensorProto.FLOAT16 if config.dtype == torch.float16 else TensorProto.FLOAT
@ -252,7 +295,7 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig):
),
]
shape_dict = config.shape_dict()
shape_dict = config.symbolic_shape_dict() if use_symbolic_shape else config.shape_dict()
inputs = [
helper.make_tensor_value_info(input_name, float_type, list(shape_dict[input_name]))
for input_name in input_names

View file

@ -7,17 +7,39 @@
Test MultiHeadAttention operator for CUDA and CPU.
"""
import concurrent.futures
import itertools
import unittest
from typing import Optional
from enum import IntEnum
from typing import Dict, List, Optional
import numpy
import torch
from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention
from benchmark_mha import (
InputFormats,
MultiHeadAttentionConfig,
OrtMultiHeadAttention,
create_multi_head_attention_onnx_model,
)
from einops import rearrange
from parameterized import parameterized
import onnxruntime
from onnxruntime import InferenceSession
class SdpaKernel(IntEnum):
"""Bit flags for sdpa_kernel CUDA provider option"""
DEFAULT = 0
FLASH_ATTENTION = 1
EFFICIENT_ATTENTION = 2
TRT_FUSED_ATTENTION = 4
CUDNN_FLASH_ATTENTION = 8
MATH = 16
TRT_FLASH_ATTENTION = 32
TRT_CROSS_ATTENTION = 64
TRT_CAUSAL_ATTENTION = 128
def attention_reference(
@ -105,9 +127,16 @@ def mha_with_past_reference(
def get_provider_support_info(provider: str, use_kv_cache: bool):
if provider == "CUDAExecutionProvider":
formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]
if not use_kv_cache:
formats.append(InputFormats.Q_K_V_BSNH_BSNH_BSNH)
formats = [
InputFormats.Q_K_V_BSNH_BSNH_BSNH,
InputFormats.Q_KV_BSNH_BSN2H,
InputFormats.QKV_BSN3H,
InputFormats.Q_K_V_BSNH_BNSH_BNSH,
]
else:
formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH]
device_id = torch.cuda.current_device()
device = torch.device("cuda", device_id)
dtype = torch.float16
@ -121,15 +150,16 @@ def get_provider_support_info(provider: str, use_kv_cache: bool):
return device, dtype, formats
def has_cuda_support():
def get_compute_capability():
if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
major, _ = torch.cuda.get_device_capability()
return major >= 6
return False
major, minor = torch.cuda.get_device_capability()
sm = major * 10 + minor
return sm
return 0
def no_kv_cache_test_cases(provider: str, comprehensive: bool):
if provider == "CUDAExecutionProvider" and not has_cuda_support():
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
return
yield
@ -192,7 +222,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
def kv_cache_test_cases(provider: str, comprehensive: bool):
if provider == "CUDAExecutionProvider" and not has_cuda_support():
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
return
yield
@ -262,6 +292,92 @@ def mha_test_cases(provider: str, comprehensive: bool):
)
def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
return
yield
batch_sizes = [1, 2]
sequence_lengths = [1, 16, 127, 128, 255, 256, 383, 384, 400] if comprehensive else [1, 64, 128, 256]
heads = [4]
head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64]
device, dtype, formats = get_provider_support_info(provider, False)
for format in formats:
for causal in [True, False]:
for num_heads in heads:
for head_size in head_sizes:
configs = [] # list of configurations to run in parallel
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
config = MultiHeadAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
num_heads=num_heads,
head_size=head_size,
causal=causal,
past_sequence_length=0,
kv_sequence_length=sequence_length,
max_cache_sequence_length=None,
provider=provider,
device=device,
dtype=dtype,
use_kv_cache=False,
share_past_present_buffer=False,
input_format=format,
)
configs.append(config)
yield configs
def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
return
yield
batch_sizes = [1, 2]
sequence_lengths = [1, 32, 127, 128, 383, 384, 400] if comprehensive else [1, 32, 127, 128]
heads = [4]
head_sizes = [8, 16, 32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if comprehensive else [32, 64]
sequence_length = 1
device, dtype, formats = get_provider_support_info(provider, True)
for format in formats:
for causal in [True, False]:
for num_heads in heads:
for head_size in head_sizes:
configs = []
for batch_size in batch_sizes:
for past_sequence_length in sequence_lengths:
config = MultiHeadAttentionConfig(
batch_size=batch_size,
sequence_length=sequence_length,
num_heads=num_heads,
head_size=head_size,
causal=causal,
past_sequence_length=past_sequence_length,
kv_sequence_length=sequence_length,
max_cache_sequence_length=None,
provider=provider,
device=device,
dtype=dtype,
use_kv_cache=True,
share_past_present_buffer=False,
input_format=format,
)
configs.append(config)
yield configs
def multi_thread_test_cases(provider: str, comprehensive: bool):
return itertools.chain(
no_kv_cache_multi_thread_test_cases(provider, comprehensive),
kv_cache_multi_thread_test_cases(provider, comprehensive),
)
def causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
@ -346,20 +462,189 @@ def parity_check_mha(
)
def parity_check_mha_multi_threading(
test_inputs: List[Dict],
rtol: float = 1e-3,
atol: float = 1e-3,
sdpa_kernel: int = SdpaKernel.DEFAULT,
max_threads: int = 5,
verbose: bool = False,
):
# Use the first config to create a session, which is shared by all configs to run in parallel.
config = test_inputs[0]["config"]
# For now, MHA CUDA kernel does not support causal so skip such test cases.
if config.causal and config.provider == "CUDAExecutionProvider":
return None
# Some kernel does not support certain input format.
if sdpa_kernel not in [
SdpaKernel.DEFAULT,
SdpaKernel.FLASH_ATTENTION,
SdpaKernel.EFFICIENT_ATTENTION,
] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]:
return None
if verbose:
print(f"create a shared session with {vars(config)}")
onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True)
if config.provider == "CUDAExecutionProvider":
provider_options = {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": int(sdpa_kernel)}
providers = [(config.provider, provider_options), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
ort_session = InferenceSession(onnx_model_str, providers=providers)
def convert_to_ort_inputs(feed_dict):
ort_inputs = {}
for k, v in feed_dict.items():
if isinstance(v, numpy.ndarray):
ort_inputs[k] = v
else:
ort_inputs[k] = v.detach().cpu().numpy()
return ort_inputs
def check_parity_with_config(i: int):
config = test_inputs[i]["config"]
if verbose:
print(f"Thread {i} with {vars(config)}")
ort_inputs = test_inputs[i]["ort_inputs"]
if verbose:
print(f"Thread {i} ort inputs: {ort_inputs}")
ort_outputs = ort_session.run(None, convert_to_ort_inputs(ort_inputs))
out = numpy.reshape(
ort_outputs[0], (config.batch_size, config.sequence_length, config.num_heads, config.head_size)
)
# Create reference inputs
config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH
ref_inputs = test_inputs[i]["ref_inputs"]
if verbose:
print(f"Thread {i} ref inputs: {ref_inputs}")
q = (
ref_inputs["query"]
.reshape((config.batch_size, config.sequence_length, config.num_heads, config.head_size))
.transpose(1, 2)
)
k = (
ref_inputs["key"]
.reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size))
.transpose(1, 2)
)
v = (
ref_inputs["value"]
.reshape((config.batch_size, config.kv_sequence_length, config.num_heads, config.head_size))
.transpose(1, 2)
)
mask = None
if config.causal:
mask = causal_mask(config.sequence_length, config.total_sequence_length, device=config.device)
k_cache = None
v_cache = None
if config.use_kv_cache:
past_k = ref_inputs["past_key"]
past_v = ref_inputs["past_value"]
out_ref, k_cache, v_cache = mha_with_past_reference(config, past_k, past_v, q, k, v, mask=mask)
else:
out_ref = attention_reference(config.head_size, q, k, v, mask=mask)
try:
numpy.testing.assert_allclose(
out,
out_ref.detach().cpu().numpy(),
rtol=rtol,
atol=atol,
equal_nan=True,
err_msg=f"output not close: {config=}",
)
if config.use_kv_cache:
present_key = ort_outputs[1]
numpy.testing.assert_allclose(
k_cache.detach().cpu().numpy(),
present_key,
rtol=rtol,
atol=atol,
equal_nan=True,
err_msg=f"present_key not close: {config=}",
)
present_value = ort_outputs[2]
numpy.testing.assert_allclose(
v_cache.detach().cpu().numpy(),
present_value,
rtol=rtol,
atol=atol,
equal_nan=True,
err_msg=f"present_value not close: {config=}",
)
except AssertionError as e:
print(f"Failed with {vars(config)}: {e}")
return e
if verbose:
print(f"Passed: {vars(config)}")
return None
num_threads = min(max_threads, len(test_inputs))
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
future_tasks = [executor.submit(check_parity_with_config, i) for i in range(num_threads)]
for future in concurrent.futures.as_completed(future_tasks):
result = future.result()
if result is not None:
return result
return None
# Do not run too many tests in CI pipeline. Change it to True to run all combinations in dev machine.
comprehensive_mode = False
class TestMultiHeadAttention(unittest.TestCase):
# TODO: enable tests on CUDAExecutionProvider after fixing the issue.
# @parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True)
# def test_mha_cuda(self, config):
# parity_check_mha(config)
@parameterized.expand(mha_test_cases("CUDAExecutionProvider", comprehensive_mode), skip_on_empty=True)
def test_mha_cuda(self, config):
parity_check_mha(config)
@parameterized.expand(mha_test_cases("CPUExecutionProvider", comprehensive_mode), skip_on_empty=True)
def test_mha_cpu(self, config):
parity_check_mha(config)
def run_mha_cuda_multi_threading(self, spda_kernel):
for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode):
test_inputs = []
for config in configs:
ort_inputs = config.random_inputs()
# Create reference inputs
old_format = config.input_format
config.input_format = InputFormats.Q_K_V_BSNH_BSNH_BSNH
ref_inputs = config.random_inputs()
config.input_format = old_format
test_inputs.append({"config": config, "ort_inputs": ort_inputs, "ref_inputs": ref_inputs})
exception = parity_check_mha_multi_threading(test_inputs, sdpa_kernel=spda_kernel, max_threads=len(configs))
assert exception is None, f"{spda_kernel=}, {vars(configs[0])}, {exception}"
def test_mha_cuda_multi_threading(self):
self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT)
def test_mha_cuda_multi_threading_efficient(self):
self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION)
def test_mha_cuda_multi_threading_trt(self):
sm = get_compute_capability()
if sm in [75, 80, 86, 89]:
self.run_mha_cuda_multi_threading(
SdpaKernel.TRT_FUSED_ATTENTION
| SdpaKernel.TRT_FLASH_ATTENTION
| SdpaKernel.TRT_CROSS_ATTENTION
| SdpaKernel.TRT_CAUSAL_ATTENTION
)
if __name__ == "__main__":
with torch.no_grad():