[CUDA] Attention kernel provider option (#21344)

### Description
* Add a cuda provider option `sdpa_kernel` to choose which attention kernel to run for testing purpose. 
* Allow dump which attention kernel is used per node.
* Reserve  a flag for cudnn flash attention which will be added soon.

#### CUDA provider option sdpa_kernel
Instead of setting environment variable, we also support setting it in
provider option. Note that the setting is global per session. That could
help performance testing of each kernel.

#### Attention Kernel Debug Info
Set an environment variable `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`,
and ORT will print sdpa kernel used in each node:

For example 
```
ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 ./onnxruntime_test_all --gtest_filter=MultiHeadAttentionTest*
```
It will show debug information of kernel used in testing:
```
[ RUN      ] MultiHeadAttentionTest.SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=0 TRT_FUSED_ATTENTION=1 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=1 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 TRT_FUSED_ATTENTION=1
AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=1 TRT_FUSED_ATTENTION=0 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=0 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1
Operator=MultiHeadAttention Node=node1 DataType=fp16 EFFICIENT_ATTENTION=1
```
In this test case, the debug info shows that one session uses trt fused
attention and another session use efficient attention.
This commit is contained in:
Tianlei Wu 2024-07-19 13:58:54 -07:00 committed by GitHub
parent 01df8c787d
commit 6ffaaebb60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 645 additions and 110 deletions

View file

@ -15,6 +15,8 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/attention_kernel_options.h"
"bert/attention_kernel_options.cc"
"bert/decoder_attention_impl.h"
"bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"

View file

@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common)
if (MSVC)
# Cutlass code has an issue with the following:
# warning C4100: 'magic': unreferenced formal parameter

View file

@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 {
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option
};

View file

@ -147,6 +147,23 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_
} // namespace sparse_attention
namespace attention {
enum class AttentionBackend : int {
FLASH_ATTENTION = 1,
EFFICIENT_ATTENTION = 2,
TRT_FUSED_ATTENTION = 4,
CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention.
MATH = 16, // unfused kernel cannot be disabled right now.
// The following kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 32,
TRT_CROSS_ATTENTION = 64,
TRT_CAUSAL_ATTENTION = 128,
};
// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).
constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO";
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
@ -157,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";
// Environment variable to enable or disable cuDNN flash attention.
constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION";
// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";
@ -166,11 +186,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;
// Minimum sequence length to perfer memory efficient attention when data type is float32
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";
// Default value for minimum sequence length to enable memory efficient attention in FP32.
constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256;
// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention
constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV";
// Default value for the above setting.
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;

View file

@ -3,7 +3,6 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
@ -40,36 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
disable_fused_self_attention_ =
sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
kernel_options_ = this->GetAttentionKernelOptions();
enable_trt_flash_attention_ =
sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_fused_causal_attention_ =
sizeof(T) == 2 &&
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ =
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention();
#if USE_FLASH_ATTENTION
disable_flash_attention_ =
sizeof(T) != 2 ||
onnxruntime::ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
#else
disable_flash_attention_ = true;
min_seq_len_for_flash_attention_packed_qkv_ = 0;
#endif
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
}
template <typename T>
@ -134,7 +114,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.num_heads,
parameters.num_heads);
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
use_flash_attention = false;
}
// Allocate buffers
@ -220,7 +200,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == past &&
nullptr == present &&
(nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
if (use_memory_efficient_attention) {
@ -231,6 +211,20 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_memory_efficient_attention = false;
#endif
if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_runner != nullptr) {
debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length);
}
debug_info.Print("Attention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}
cublasHandle_t cublas = GetCublasHandle(context);
typedef typename ToCudaType<T>::MappedType CudaT;
@ -268,7 +262,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_fused_cross_attention,
use_memory_efficient_attention);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
;
typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;

View file

@ -8,6 +8,7 @@
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
namespace onnxruntime {
namespace contrib {
@ -27,9 +28,10 @@ class Attention final : public CudaKernel, public AttentionBase {
bool enable_trt_flash_attention_;
bool enable_fused_causal_attention_;
bool disable_memory_efficient_attention_;
int min_seq_len_for_flash_attention_packed_qkv_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable std::once_flag fused_fp16_runner_created_;
const AttentionKernelOptions* kernel_options_;
};
} // namespace cuda

View file

@ -0,0 +1,166 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#include <iomanip>
#include <iostream>
#include <sstream>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
using namespace onnxruntime::contrib::attention;
namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
use_unfused_ = (value & static_cast<int>(AttentionBackend::MATH)) > 0;
use_trt_flash_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION)) > 0;
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0;
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
} else {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableFusedCausalAttention, false);
}
enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault<bool>(kEnableAttentionKernelDebugInfo, false);
// When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing.
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
kMinSeqLenForFlashAttentionPackedQKV,
value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV);
min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault<int>(
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);
if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
use_flash_attention_ = false;
#endif
#ifndef USE_MEMORY_EFFICIENT_ATTENTION
use_efficient_attention_ = false;
#endif
}
}
void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
if (this->enable_kernel_debug_info_) {
this->Print();
}
});
}
void AttentionKernelOptions::Print() const {
std::stringstream sstream;
sstream << "AttentionKernelOptions:";
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_);
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_);
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_);
sstream << " MATH=" << int(use_unfused_);
if (!use_unfused_) {
sstream << std::endl
<< "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored.";
}
// Output text in Cyan color to make it easier to spot
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}
// Classify the kernel used in TRT fused runner.
void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) {
if (causal) {
use_trt_causal_attention = true;
} else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) {
use_trt_flash_attention = true;
} else {
use_trt_fused_attention = true;
}
}
void AttentionKernelDebugInfo::Print(const char* operator_name,
const std::string& node_name,
bool is_float16,
bool is_bfloat16) const {
std::stringstream sstream;
sstream << "Operator=" << operator_name;
if (node_name.length() > 0) {
sstream << " Node=" << node_name;
}
if (is_bfloat16) {
sstream << " DataType=bf16";
} else if (is_float16) {
sstream << " DataType=fp16";
} else {
sstream << " DataType=fp32";
}
if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
}
if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
}
if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
}
if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
}
if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
}
if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
}
if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
}
bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());
// Fall back to unfused when no fused kernel is enabled.
if (!use_fused) {
sstream << " MATH=1";
}
// Output text in Cyan color to make it easier to spot.
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
}
} // namespace onnxruntime

View file

@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <mutex>
#include <optional>
#include <string>
namespace onnxruntime {
struct AttentionKernelDebugInfo {
std::optional<bool> use_flash_attention = std::nullopt;
std::optional<bool> use_efficient_attention = std::nullopt;
std::optional<bool> use_trt_fused_attention = std::nullopt;
std::optional<bool> use_cudnn_flash_attention = std::nullopt;
std::optional<bool> use_trt_flash_attention = std::nullopt;
std::optional<bool> use_trt_cross_attention = std::nullopt;
std::optional<bool> use_trt_causal_attention = std::nullopt;
void SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length);
void Print(const char* operator_name, const std::string& node_name, bool is_float16, bool is_bfloat16) const;
};
class AttentionKernelOptions {
public:
void InitializeOnce(int sdpa_kernel, bool use_build_flag);
bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
bool UseTrtFusedAttention() const { return use_trt_fused_attention_; }
bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; }
bool UseUnfusedAttention() const { return use_unfused_; }
bool UseTrtFlashAttention() const { return use_trt_flash_attention_; }
bool UseTrtCrossAttention() const { return use_trt_cross_attention_; }
bool UseTrtCausalAttention() const { return use_trt_causal_attention_; }
bool AllowDebugInfo() const { return enable_kernel_debug_info_; }
int MinSeqLenForFlashAttentionPackedQkv() const { return min_seq_len_for_flash_attention_packed_qkv_; }
int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; }
protected:
void Print() const;
void Initialize(int value, bool use_build_flag);
private:
bool use_flash_attention_{true};
bool use_efficient_attention_{true};
bool use_trt_fused_attention_{true};
bool use_cudnn_flash_attention_{false};
bool use_unfused_{true};
bool use_trt_flash_attention_{true};
bool use_trt_cross_attention_{true};
// Causal attention is disabled by default in #14732.
bool use_trt_causal_attention_{false};
bool enable_kernel_debug_info_{false};
int min_seq_len_for_flash_attention_packed_qkv_{0};
int min_seq_len_for_efficient_attention_fp32_{0};
std::once_flag initialize_once_flag_;
};
} // namespace onnxruntime

View file

@ -52,20 +52,13 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
#else
disable_flash_attention_ = true;
#endif
kernel_options_ = this->GetAttentionKernelOptions();
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#if USE_MEMORY_EFFICIENT_ATTENTION
// Memory efficient attention only supports float and float16, not bfloat16.
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value || !kernel_options_->UseEfficientAttention();
if (!disable_flash_attention_) {
zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
}
@ -161,7 +154,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@ -201,6 +194,17 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto unpacked_qkv_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
#endif
if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;
debug_info.Print("GroupQueryAttention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}
// seqlens_k buffer
size_t seqlens_k_bytes = 0;
seqlens_k_bytes = sizeof(int) * parameters.batch_size;

View file

@ -6,6 +6,7 @@
#include <memory>
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
namespace onnxruntime {
namespace contrib {
@ -32,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel {
bool disable_memory_efficient_attention_;
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
IAllocatorUniquePtr<int> zeros_;
const AttentionKernelOptions* kernel_options_;
};
} // namespace cuda

View file

@ -2,7 +2,6 @@
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/multihead_attention.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
@ -47,31 +46,16 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
kernel_options_ = this->GetAttentionKernelOptions();
enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
#else
disable_flash_attention_ = true;
min_seq_len_for_flash_attention_packed_qkv_ = 0;
#endif
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
disable_fused_cross_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention();
// Allocate cache buffers
constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast<size_t>(kCumulatedSequenceLengthCacheMaxBatchSize) + 1);
@ -155,7 +139,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.num_heads);
// When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512.
if (use_flash_attention && key == nullptr && value == nullptr &&
parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
use_flash_attention = false;
}
// Allocate buffers
@ -229,9 +213,10 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
#if USE_MEMORY_EFFICIENT_ATTENTION
int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32();
bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16
parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 ||
parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32;
parameters.sequence_length >= length_threshold ||
parameters.kv_sequence_length >= length_threshold;
bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
@ -249,6 +234,21 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_memory_efficient_attention = false;
#endif
if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_fp16_runner_ != nullptr) {
debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length);
}
debug_info.Print("MultiHeadAttention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}
// When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace.
// TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime.
bool no_qkv_workspace = nullptr == value &&

View file

@ -8,6 +8,7 @@
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
namespace onnxruntime {
namespace contrib {
@ -31,12 +32,12 @@ class MultiHeadAttention final : public CudaKernel {
bool disable_fused_cross_attention_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
int min_seq_len_for_flash_attention_packed_qkv_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable std::once_flag fused_fp16_runner_created_;
mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_;
mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_;
const AttentionKernelOptions* kernel_options_;
};
} // namespace cuda

View file

@ -33,12 +33,11 @@ REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
TrtFusedAttention<T>::TrtFusedAttention() {
disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
TrtFusedAttention<T>::TrtFusedAttention(const OpKernelInfo& info)
: CudaKernel(info) {
kernel_options_ = this->GetAttentionKernelOptions();
disable_fused_runner_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
}
template <typename T>
@ -86,7 +85,8 @@ template class TrtFusedAttention<float>;
template class TrtFusedAttention<MLFloat16>;
template <typename T>
PackedAttention<T>::PackedAttention(const OpKernelInfo& info) : TrtFusedAttention<T>(), CudaKernel(info) {
PackedAttention<T>::PackedAttention(const OpKernelInfo& info)
: TrtFusedAttention<T>(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int32_t>(num_heads);
@ -268,7 +268,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* relative_position_bias = context->Input<Tensor>(5);
PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
parameters.use_tf32 = this->UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
@ -295,6 +295,19 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
#endif
if (this->kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_runner != nullptr) {
debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length);
}
debug_info.Print("PackedAttention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}
typedef typename ToCudaType<T>::MappedType CudaT;
CudaT one = ToCudaType<T>::FromFloat(1.0f);
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
@ -313,7 +326,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, this->UseTF32()));
constexpr size_t element_size = sizeof(T);
constexpr bool no_qkv_workspace = false; // need workspace to add bias
@ -341,7 +354,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.use_memory_efficient_attention = use_memory_efficient_attention;
return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
return QkvToContext<CudaT>(device_prop, cublas, this->Stream(context), parameters, data);
}
} // namespace cuda

View file

@ -9,6 +9,7 @@
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
namespace onnxruntime {
namespace contrib {
@ -17,14 +18,16 @@ namespace cuda {
using namespace onnxruntime::cuda;
template <typename T>
class TrtFusedAttention {
class TrtFusedAttention : public CudaKernel {
public:
TrtFusedAttention();
TrtFusedAttention(const OpKernelInfo& info);
protected:
MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const;
protected:
const AttentionKernelOptions* kernel_options_;
bool disable_fused_runner_;
bool enable_trt_flash_attention_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
@ -32,7 +35,7 @@ class TrtFusedAttention {
};
template <typename T>
class PackedAttention final : public TrtFusedAttention<T>, public CudaKernel {
class PackedAttention final : public TrtFusedAttention<T> {
public:
PackedAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;

View file

@ -35,30 +35,16 @@ REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
PackedMultiHeadAttention<T>::PackedMultiHeadAttention(const OpKernelInfo& info)
: TrtFusedAttention<T>(), CudaKernel(info) {
: TrtFusedAttention<T>(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int32_t>(num_heads);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
#if USE_FLASH_ATTENTION
disable_flash_attention_ = sizeof(T) != 2 || onnxruntime::ParseEnvironmentVariableWithDefault<bool>(
attention::kDisableFlashAttention, false);
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
attention::kMinSeqLenForFlashAttentionPackedQKV,
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
#else
disable_flash_attention_ = true;
min_seq_len_for_flash_attention_packed_qkv_ = 0;
#endif
disable_flash_attention_ = sizeof(T) != 2 || !this->kernel_options_->UseFlashAttention();
#if USE_MEMORY_EFFICIENT_ATTENTION
disable_memory_efficient_attention_ = onnxruntime::ParseEnvironmentVariableWithDefault<bool>(
attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
disable_memory_efficient_attention_ = !this->kernel_options_->UseEfficientAttention();
}
template <typename T>
@ -228,7 +214,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
const Tensor* relative_position_bias = context->Input<Tensor>(6);
PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
parameters.use_tf32 = this->UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(),
key,
value,
@ -255,7 +241,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
if (use_flash_attention && key == nullptr && value == nullptr &&
parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
parameters.sequence_length < this->kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
use_flash_attention = false;
}
}
@ -271,11 +257,25 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0;
use_memory_efficient_attention =
is_good_for_rpb &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
}
#endif
if (this->kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_runner != nullptr) {
debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length);
}
debug_info.Print("PackedMultiHeadAttention",
this->Node().Name(),
std::is_same<T, MLFloat16>::value,
std::is_same<T, BFloat16>::value);
}
typedef typename ToCudaType<T>::MappedType CudaT;
cublasHandle_t cublas = this->GetCublasHandle(context);

View file

@ -4,13 +4,14 @@
#pragma once
#include "contrib_ops/cuda/bert/packed_attention.h"
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T>
class PackedMultiHeadAttention final : public TrtFusedAttention<T>, public CudaKernel {
class PackedMultiHeadAttention final : public TrtFusedAttention<T> {
public:
PackedMultiHeadAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;
@ -32,7 +33,6 @@ class PackedMultiHeadAttention final : public TrtFusedAttention<T>, public CudaK
bool disable_memory_efficient_attention_;
bool disable_flash_attention_;
int min_seq_len_for_flash_attention_packed_qkv_;
};
} // namespace cuda

View file

@ -17,6 +17,10 @@
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/tunable/cuda_tuning_context.h"
#ifndef DISABLE_CONTRIB_OPS
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#endif
namespace onnxruntime {
void RunOnUnload(std::function<void()> function);
@ -80,6 +84,14 @@ class CUDAExecutionProvider : public IExecutionProvider {
bool IsNHWCPreferred() const { return info_.prefer_nhwc; }
bool UseTF32() const { return info_.use_tf32; }
#ifndef DISABLE_CONTRIB_OPS
// Attention kernel options parsed from sdpa_kernel cuda provider option.
const AttentionKernelOptions* GetAttentionKernelOptions() const {
attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true);
return &attention_kernel_options_;
}
#endif
ProviderOptions GetProviderOptions() const override {
return CUDAExecutionProviderInfo::ToProviderOptions(info_);
}
@ -110,6 +122,11 @@ class CUDAExecutionProvider : public IExecutionProvider {
// the tuning context might be altered when calling into a TunableOp
mutable cuda::tunable::CudaTuningContext tuning_context_;
#ifndef DISABLE_CONTRIB_OPS
// Attention kernel options parsed from sdpa_kernel cuda provider option.
mutable AttentionKernelOptions attention_kernel_options_;
#endif
class PerThreadContext final {
public:
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,

View file

@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s
constexpr const char* kPreferNHWCMode = "prefer_nhwc";
constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream";
constexpr const char* kUseTF32 = "use_tf32";
constexpr const char* kSdpaKernel = "sdpa_kernel";
} // namespace provider_option_names
} // namespace cuda
@ -117,6 +118,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
.AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc)
.AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream)
.AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32)
.AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel)
.AddValueParser(
cuda::provider_option_names::kTunableOpEnable,
[&info](const std::string& value_str) -> Status {
@ -170,6 +172,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
{cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)},
{cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)},
{cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)},
{cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)},
};
return options;
@ -192,6 +195,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid
{cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)},
{cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)},
{cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)},
{cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)},
};
return options;

View file

@ -79,6 +79,8 @@ struct CUDAExecutionProviderInfo {
// By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices.
bool use_tf32{true};
int sdpa_kernel{0};
static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info);
static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info);
@ -91,6 +93,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> {
size_t value{0xbc9f1d34}; // seed
// Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each)
// Do not exceed 32 bits here otherwise some bits will be lost in x86.
size_t data = static_cast<size_t>(info.device_id) ^
(static_cast<size_t>(info.arena_extend_strategy) << 16) ^
(static_cast<size_t>(info.cudnn_conv_algo_search) << 18) ^
@ -109,6 +112,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> {
onnxruntime::HashCombine(info.gpu_mem_limit, value);
onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value);
onnxruntime::HashCombine(info.sdpa_kernel, value);
// Memory pointers
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.user_compute_stream), value);

View file

@ -94,6 +94,12 @@ class CudaKernel : public OpKernel {
return provider_->UseTF32();
}
#ifndef DISABLE_CONTRIB_OPS
const AttentionKernelOptions* GetAttentionKernelOptions() const {
return provider_->GetAttentionKernelOptions();
}
#endif
tunable::CudaTuningContext* GetTuningContext() const {
return static_cast<tunable::CudaTuningContext*>(provider_->GetTuningContext());
}

View file

@ -226,6 +226,7 @@ struct CUDA_Provider : Provider {
info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0;
info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0;
info.use_tf32 = params->use_tf32 != 0;
info.sdpa_kernel = params->sdpa_kernel;
return std::make_shared<CUDAProviderFactory>(info);
}
@ -260,6 +261,7 @@ struct CUDA_Provider : Provider {
cuda_options.prefer_nhwc = internal_options.prefer_nhwc;
cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream;
cuda_options.use_tf32 = internal_options.use_tf32;
cuda_options.sdpa_kernel = internal_options.sdpa_kernel;
}
ProviderOptions GetProviderOptions(const void* provider_options) override {

View file

@ -394,8 +394,8 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu
}
#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32 ||
data.kv_sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32) {
if (data.sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32 ||
data.kv_sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32) {
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(

View file

@ -0,0 +1,221 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef DISABLE_CONTRIB_OPS
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "test/util/include/scoped_env_vars.h"
#include "gtest/gtest.h"
#include <unordered_map>
#include <string>
using onnxruntime::AttentionKernelOptions;
using onnxruntime::contrib::attention::AttentionBackend;
namespace onnxruntime {
namespace test {
TEST(AttentionKernelOptionsTest, NonZeroValue) {
{
AttentionKernelOptions options;
int value = static_cast<int>(AttentionBackend::FLASH_ATTENTION) | static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION);
options.InitializeOnce(value, false);
ASSERT_TRUE(options.UseFlashAttention());
ASSERT_TRUE(options.UseEfficientAttention());
ASSERT_FALSE(options.UseTrtFusedAttention());
ASSERT_FALSE(options.UseCudnnFlashAttention());
ASSERT_FALSE(options.UseUnfusedAttention());
ASSERT_FALSE(options.UseTrtFlashAttention());
ASSERT_FALSE(options.UseTrtCrossAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
}
{
AttentionKernelOptions options;
int value = static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION) | static_cast<int>(AttentionBackend::MATH);
options.InitializeOnce(value, false);
ASSERT_FALSE(options.UseFlashAttention());
ASSERT_FALSE(options.UseEfficientAttention());
ASSERT_TRUE(options.UseTrtFusedAttention());
ASSERT_FALSE(options.UseCudnnFlashAttention());
ASSERT_TRUE(options.UseUnfusedAttention());
ASSERT_FALSE(options.UseTrtFlashAttention());
ASSERT_FALSE(options.UseTrtCrossAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
}
{
AttentionKernelOptions options;
int value = static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION);
options.InitializeOnce(value, false);
ASSERT_FALSE(options.UseFlashAttention());
ASSERT_FALSE(options.UseEfficientAttention());
ASSERT_FALSE(options.UseTrtFusedAttention());
ASSERT_TRUE(options.UseCudnnFlashAttention());
ASSERT_FALSE(options.UseUnfusedAttention());
ASSERT_FALSE(options.UseTrtFlashAttention());
ASSERT_FALSE(options.UseTrtCrossAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
}
{
AttentionKernelOptions options;
int value = static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION);
options.InitializeOnce(value, false);
ASSERT_FALSE(options.UseFlashAttention());
ASSERT_FALSE(options.UseEfficientAttention());
ASSERT_FALSE(options.UseTrtFusedAttention());
ASSERT_FALSE(options.UseCudnnFlashAttention());
ASSERT_FALSE(options.UseUnfusedAttention());
ASSERT_TRUE(options.UseTrtFlashAttention());
ASSERT_FALSE(options.UseTrtCrossAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
}
{
AttentionKernelOptions options;
int value = static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION) | static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION);
options.InitializeOnce(value, false);
ASSERT_FALSE(options.UseFlashAttention());
ASSERT_FALSE(options.UseEfficientAttention());
ASSERT_FALSE(options.UseTrtFusedAttention());
ASSERT_FALSE(options.UseCudnnFlashAttention());
ASSERT_FALSE(options.UseUnfusedAttention());
ASSERT_FALSE(options.UseTrtFlashAttention());
ASSERT_TRUE(options.UseTrtCrossAttention());
ASSERT_TRUE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
}
// Test environment variables are ignored when option value is non-zero
// Test default min sequence lengths are zeros
{
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}}};
AttentionKernelOptions options;
int value = static_cast<int>(AttentionBackend::FLASH_ATTENTION);
options.InitializeOnce(value, false);
ASSERT_TRUE(options.UseFlashAttention());
ASSERT_FALSE(options.UseEfficientAttention());
ASSERT_FALSE(options.UseTrtFusedAttention());
ASSERT_FALSE(options.UseCudnnFlashAttention());
ASSERT_FALSE(options.UseUnfusedAttention());
ASSERT_FALSE(options.UseTrtFlashAttention());
ASSERT_FALSE(options.UseTrtCrossAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
}
// Test min sequence lengths can be parsed from environment variables when option value is non-zero
{
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
{onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"},
{onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}};
AttentionKernelOptions options;
int value = static_cast<int>(AttentionBackend::FLASH_ATTENTION);
options.InitializeOnce(value, false);
ASSERT_TRUE(options.UseFlashAttention());
ASSERT_FALSE(options.UseEfficientAttention());
ASSERT_FALSE(options.UseTrtFusedAttention());
ASSERT_FALSE(options.UseCudnnFlashAttention());
ASSERT_FALSE(options.UseUnfusedAttention());
ASSERT_FALSE(options.UseTrtFlashAttention());
ASSERT_FALSE(options.UseTrtCrossAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256);
}
}
// Test all environment variables take effect when option value is 0.
TEST(AttentionKernelOptionsTest, DefaultOptionWithEnvVar) {
constexpr int value = 0;
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
{onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"},
{onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}};
AttentionKernelOptions options;
options.InitializeOnce(value, false);
ASSERT_TRUE(options.UseFlashAttention());
ASSERT_TRUE(options.UseEfficientAttention());
ASSERT_TRUE(options.UseTrtFusedAttention());
ASSERT_TRUE(options.UseCudnnFlashAttention());
ASSERT_TRUE(options.UseUnfusedAttention());
ASSERT_TRUE(options.UseTrtFlashAttention());
ASSERT_TRUE(options.UseTrtCrossAttention());
ASSERT_TRUE(options.UseTrtCausalAttention());
ASSERT_TRUE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256);
}
// Test default min sequence lengths when environment variables are not set.
TEST(AttentionKernelOptionsTest, DefaultMinSeqLens) {
constexpr int value = 0;
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}}};
AttentionKernelOptions options;
options.InitializeOnce(value, false);
ASSERT_FALSE(options.UseFlashAttention());
ASSERT_FALSE(options.UseEfficientAttention());
ASSERT_FALSE(options.UseTrtFusedAttention());
ASSERT_FALSE(options.UseCudnnFlashAttention());
ASSERT_TRUE(options.UseUnfusedAttention());
ASSERT_FALSE(options.UseTrtFlashAttention());
ASSERT_FALSE(options.UseTrtCrossAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
ASSERT_FALSE(options.UseTrtCausalAttention());
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(),
onnxruntime::contrib::attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(),
onnxruntime::contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32);
}
} // namespace test
} // namespace onnxruntime
#endif

View file

@ -446,6 +446,8 @@ class TestInferenceSession(unittest.TestCase):
test_get_and_set_option_with_values("use_tf32", ["1", "0"])
test_get_and_set_option_with_values("sdpa_kernel", ["0", "1", "2"])
option["gpu_external_alloc"] = "0"
option["gpu_external_free"] = "0"
option["gpu_external_empty_cache"] = "0"