mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-10 00:38:54 +00:00
[CUDA] Attention kernel provider option (#21344)
### Description * Add a cuda provider option `sdpa_kernel` to choose which attention kernel to run for testing purpose. * Allow dump which attention kernel is used per node. * Reserve a flag for cudnn flash attention which will be added soon. #### CUDA provider option sdpa_kernel Instead of setting environment variable, we also support setting it in provider option. Note that the setting is global per session. That could help performance testing of each kernel. #### Attention Kernel Debug Info Set an environment variable `ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1`, and ORT will print sdpa kernel used in each node: For example ``` ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO=1 ./onnxruntime_test_all --gtest_filter=MultiHeadAttentionTest* ``` It will show debug information of kernel used in testing: ``` [ RUN ] MultiHeadAttentionTest.SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=0 TRT_FUSED_ATTENTION=1 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=1 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1 Operator=MultiHeadAttention Node=node1 DataType=fp16 TRT_FUSED_ATTENTION=1 AttentionKernelOptions: FLASH_ATTENTION=0 EFFICIENT_ATTENTION=1 TRT_FUSED_ATTENTION=0 CUDNN_FLASH_ATTENTION=0 TRT_FLASH_ATTENTION=0 TRT_CROSS_ATTENTION=0 TRT_CAUSAL_ATTENTION=0 MATH=1 Operator=MultiHeadAttention Node=node1 DataType=fp16 EFFICIENT_ATTENTION=1 ``` In this test case, the debug info shows that one session uses trt fused attention and another session use efficient attention.
This commit is contained in:
parent
01df8c787d
commit
6ffaaebb60
24 changed files with 645 additions and 110 deletions
|
|
@ -15,6 +15,8 @@ set(contrib_ops_excluded_files
|
|||
"bert/attention_softmax.h"
|
||||
"bert/attention_softmax.cu"
|
||||
"bert/attention_prepare_qkv.cu"
|
||||
"bert/attention_kernel_options.h"
|
||||
"bert/attention_kernel_options.cc"
|
||||
"bert/decoder_attention_impl.h"
|
||||
"bert/decoder_attention_impl.cu"
|
||||
"bert/decoder_masked_multihead_attention.h"
|
||||
|
|
|
|||
|
|
@ -786,8 +786,9 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
|
|||
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
|
||||
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
|
||||
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
|
||||
add_dependencies(onnxruntime_providers_cuda_ut onnxruntime_test_utils onnxruntime_common)
|
||||
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
|
||||
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
|
||||
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_test_utils onnxruntime_common)
|
||||
if (MSVC)
|
||||
# Cutlass code has an issue with the following:
|
||||
# warning C4100: 'magic': unreferenced formal parameter
|
||||
|
|
|
|||
|
|
@ -38,4 +38,5 @@ struct OrtCUDAProviderOptionsV2 {
|
|||
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
|
||||
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
|
||||
int use_tf32 = 1; // use TF32
|
||||
int sdpa_kernel = 0; // Scaled Dot Product Attention kernel option
|
||||
};
|
||||
|
|
|
|||
|
|
@ -147,6 +147,23 @@ constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_
|
|||
} // namespace sparse_attention
|
||||
|
||||
namespace attention {
|
||||
|
||||
enum class AttentionBackend : int {
|
||||
FLASH_ATTENTION = 1,
|
||||
EFFICIENT_ATTENTION = 2,
|
||||
TRT_FUSED_ATTENTION = 4,
|
||||
CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention.
|
||||
MATH = 16, // unfused kernel cannot be disabled right now.
|
||||
|
||||
// The following kernels might be deprecated in the future.
|
||||
TRT_FLASH_ATTENTION = 32,
|
||||
TRT_CROSS_ATTENTION = 64,
|
||||
TRT_CAUSAL_ATTENTION = 128,
|
||||
};
|
||||
|
||||
// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).
|
||||
constexpr const char* kEnableAttentionKernelDebugInfo = "ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO";
|
||||
|
||||
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
|
||||
|
||||
|
|
@ -157,6 +174,9 @@ constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATT
|
|||
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
|
||||
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable cuDNN flash attention.
|
||||
constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
|
||||
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";
|
||||
|
||||
|
|
@ -166,11 +186,15 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
|
|||
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
|
||||
|
||||
// Minimum sequence length to enable memory efficient attention in FP32.
|
||||
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;
|
||||
// Minimum sequence length to perfer memory efficient attention when data type is float32
|
||||
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";
|
||||
|
||||
// Default value for minimum sequence length to enable memory efficient attention in FP32.
|
||||
constexpr int kDefaultMinSeqLenForEfficientAttentionFp32 = 256;
|
||||
|
||||
// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention
|
||||
constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV";
|
||||
|
||||
// Default value for the above setting.
|
||||
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/attention.h"
|
||||
#include "contrib_ops/cuda/bert/bert_padding.h"
|
||||
|
|
@ -40,36 +39,17 @@ REGISTER_KERNEL_TYPED(MLFloat16)
|
|||
|
||||
template <typename T>
|
||||
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
|
||||
disable_fused_self_attention_ =
|
||||
sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
|
||||
kernel_options_ = this->GetAttentionKernelOptions();
|
||||
|
||||
enable_trt_flash_attention_ =
|
||||
sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
|
||||
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
|
||||
|
||||
enable_fused_causal_attention_ =
|
||||
sizeof(T) == 2 &&
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
disable_memory_efficient_attention_ =
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention();
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
disable_flash_attention_ =
|
||||
sizeof(T) != 2 ||
|
||||
onnxruntime::ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
|
||||
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
|
||||
attention::kMinSeqLenForFlashAttentionPackedQKV,
|
||||
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
|
||||
#else
|
||||
disable_flash_attention_ = true;
|
||||
min_seq_len_for_flash_attention_packed_qkv_ = 0;
|
||||
#endif
|
||||
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
|
||||
|
||||
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -134,7 +114,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.num_heads,
|
||||
parameters.num_heads);
|
||||
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
|
||||
if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
|
||||
if (use_flash_attention && parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
|
||||
use_flash_attention = false;
|
||||
}
|
||||
// Allocate buffers
|
||||
|
|
@ -220,7 +200,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
nullptr == past &&
|
||||
nullptr == present &&
|
||||
(nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
|
||||
|
||||
if (use_memory_efficient_attention) {
|
||||
|
|
@ -231,6 +211,20 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
constexpr bool use_memory_efficient_attention = false;
|
||||
#endif
|
||||
|
||||
if (kernel_options_->AllowDebugInfo()) {
|
||||
AttentionKernelDebugInfo debug_info;
|
||||
debug_info.use_flash_attention = use_flash_attention;
|
||||
debug_info.use_efficient_attention = use_memory_efficient_attention;
|
||||
if (fused_runner != nullptr) {
|
||||
debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length);
|
||||
}
|
||||
|
||||
debug_info.Print("Attention",
|
||||
this->Node().Name(),
|
||||
std::is_same<T, MLFloat16>::value,
|
||||
std::is_same<T, BFloat16>::value);
|
||||
}
|
||||
|
||||
cublasHandle_t cublas = GetCublasHandle(context);
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
|
@ -268,7 +262,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
use_fused_cross_attention,
|
||||
use_memory_efficient_attention);
|
||||
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
|
||||
;
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
AttentionData<CudaT> data;
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -27,9 +28,10 @@ class Attention final : public CudaKernel, public AttentionBase {
|
|||
bool enable_trt_flash_attention_;
|
||||
bool enable_fused_causal_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
int min_seq_len_for_flash_attention_packed_qkv_;
|
||||
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
|
||||
mutable std::once_flag fused_fp16_runner_created_;
|
||||
|
||||
const AttentionKernelOptions* kernel_options_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
166
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
Normal file
166
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
|
||||
|
||||
using namespace onnxruntime::contrib::attention;
|
||||
|
||||
namespace onnxruntime {
|
||||
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
|
||||
if (value > 0) {
|
||||
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
|
||||
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
|
||||
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
|
||||
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
|
||||
use_unfused_ = (value & static_cast<int>(AttentionBackend::MATH)) > 0;
|
||||
use_trt_flash_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION)) > 0;
|
||||
use_trt_cross_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION)) > 0;
|
||||
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
|
||||
} else {
|
||||
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
|
||||
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
|
||||
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
|
||||
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
|
||||
use_unfused_ = true;
|
||||
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
|
||||
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
|
||||
use_trt_causal_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableFusedCausalAttention, false);
|
||||
}
|
||||
|
||||
enable_kernel_debug_info_ = ParseEnvironmentVariableWithDefault<bool>(kEnableAttentionKernelDebugInfo, false);
|
||||
|
||||
// When value is positive, we use 0 as default minimum sequence lengths to align with common usage in testing.
|
||||
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
|
||||
kMinSeqLenForFlashAttentionPackedQKV,
|
||||
value > 0 ? 0 : kDefaultMinSeqLenForFlashAttentionPackedQKV);
|
||||
|
||||
min_seq_len_for_efficient_attention_fp32_ = ParseEnvironmentVariableWithDefault<int>(
|
||||
kMinSeqLenForEfficientAttentionFp32,
|
||||
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);
|
||||
|
||||
if (use_build_flag) {
|
||||
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
|
||||
#ifndef USE_FLASH_ATTENTION
|
||||
use_flash_attention_ = false;
|
||||
#endif
|
||||
|
||||
#ifndef USE_MEMORY_EFFICIENT_ATTENTION
|
||||
use_efficient_attention_ = false;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void AttentionKernelOptions::InitializeOnce(
|
||||
int sdpa_kernel, bool use_build_flag) {
|
||||
std::call_once(this->initialize_once_flag_, [&]() {
|
||||
this->Initialize(sdpa_kernel, use_build_flag);
|
||||
if (this->enable_kernel_debug_info_) {
|
||||
this->Print();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void AttentionKernelOptions::Print() const {
|
||||
std::stringstream sstream;
|
||||
sstream << "AttentionKernelOptions:";
|
||||
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
|
||||
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
|
||||
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
|
||||
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);
|
||||
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention_);
|
||||
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention_);
|
||||
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention_);
|
||||
sstream << " MATH=" << int(use_unfused_);
|
||||
|
||||
if (!use_unfused_) {
|
||||
sstream << std::endl
|
||||
<< "Warning: Unfused kernel cannot be disabled right now. MATH=0 is ignored.";
|
||||
}
|
||||
|
||||
// Output text in Cyan color to make it easier to spot
|
||||
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
|
||||
}
|
||||
|
||||
// Classify the kernel used in TRT fused runner.
|
||||
void AttentionKernelDebugInfo::SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length) {
|
||||
if (causal) {
|
||||
use_trt_causal_attention = true;
|
||||
} else if (enable_trt_flash_attention && sequence_length >= contrib::cuda::kMinSequenceLengthFlashAttention) {
|
||||
use_trt_flash_attention = true;
|
||||
} else {
|
||||
use_trt_fused_attention = true;
|
||||
}
|
||||
}
|
||||
|
||||
void AttentionKernelDebugInfo::Print(const char* operator_name,
|
||||
const std::string& node_name,
|
||||
bool is_float16,
|
||||
bool is_bfloat16) const {
|
||||
std::stringstream sstream;
|
||||
sstream << "Operator=" << operator_name;
|
||||
|
||||
if (node_name.length() > 0) {
|
||||
sstream << " Node=" << node_name;
|
||||
}
|
||||
|
||||
if (is_bfloat16) {
|
||||
sstream << " DataType=bf16";
|
||||
} else if (is_float16) {
|
||||
sstream << " DataType=fp16";
|
||||
} else {
|
||||
sstream << " DataType=fp32";
|
||||
}
|
||||
|
||||
if (use_flash_attention.has_value() && use_flash_attention.value()) {
|
||||
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
|
||||
}
|
||||
|
||||
if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
|
||||
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
|
||||
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
|
||||
}
|
||||
|
||||
if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
|
||||
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
|
||||
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
|
||||
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
|
||||
}
|
||||
|
||||
if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
|
||||
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
|
||||
}
|
||||
|
||||
bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
|
||||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
|
||||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
|
||||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
|
||||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
|
||||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
|
||||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());
|
||||
|
||||
// Fall back to unfused when no fused kernel is enabled.
|
||||
if (!use_fused) {
|
||||
sstream << " MATH=1";
|
||||
}
|
||||
|
||||
// Output text in Cyan color to make it easier to spot.
|
||||
std::cout << "\x1B[36m" << sstream.str() << "\x1B[0m" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
67
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
Normal file
67
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace onnxruntime {
|
||||
struct AttentionKernelDebugInfo {
|
||||
std::optional<bool> use_flash_attention = std::nullopt;
|
||||
std::optional<bool> use_efficient_attention = std::nullopt;
|
||||
std::optional<bool> use_trt_fused_attention = std::nullopt;
|
||||
std::optional<bool> use_cudnn_flash_attention = std::nullopt;
|
||||
std::optional<bool> use_trt_flash_attention = std::nullopt;
|
||||
std::optional<bool> use_trt_cross_attention = std::nullopt;
|
||||
std::optional<bool> use_trt_causal_attention = std::nullopt;
|
||||
void SetTrtFusedKernel(bool causal, bool enable_trt_flash_attention, int sequence_length);
|
||||
void Print(const char* operator_name, const std::string& node_name, bool is_float16, bool is_bfloat16) const;
|
||||
};
|
||||
|
||||
class AttentionKernelOptions {
|
||||
public:
|
||||
void InitializeOnce(int sdpa_kernel, bool use_build_flag);
|
||||
|
||||
bool UseFlashAttention() const { return use_flash_attention_; }
|
||||
bool UseEfficientAttention() const { return use_efficient_attention_; }
|
||||
bool UseTrtFusedAttention() const { return use_trt_fused_attention_; }
|
||||
bool UseCudnnFlashAttention() const { return use_cudnn_flash_attention_; }
|
||||
bool UseUnfusedAttention() const { return use_unfused_; }
|
||||
bool UseTrtFlashAttention() const { return use_trt_flash_attention_; }
|
||||
bool UseTrtCrossAttention() const { return use_trt_cross_attention_; }
|
||||
bool UseTrtCausalAttention() const { return use_trt_causal_attention_; }
|
||||
|
||||
bool AllowDebugInfo() const { return enable_kernel_debug_info_; }
|
||||
|
||||
int MinSeqLenForFlashAttentionPackedQkv() const { return min_seq_len_for_flash_attention_packed_qkv_; }
|
||||
int MinSeqLenForEfficientAttentionFp32() const { return min_seq_len_for_efficient_attention_fp32_; }
|
||||
|
||||
protected:
|
||||
void Print() const;
|
||||
|
||||
void Initialize(int value, bool use_build_flag);
|
||||
|
||||
private:
|
||||
bool use_flash_attention_{true};
|
||||
bool use_efficient_attention_{true};
|
||||
bool use_trt_fused_attention_{true};
|
||||
bool use_cudnn_flash_attention_{false};
|
||||
bool use_unfused_{true};
|
||||
|
||||
bool use_trt_flash_attention_{true};
|
||||
bool use_trt_cross_attention_{true};
|
||||
|
||||
// Causal attention is disabled by default in #14732.
|
||||
bool use_trt_causal_attention_{false};
|
||||
|
||||
bool enable_kernel_debug_info_{false};
|
||||
|
||||
int min_seq_len_for_flash_attention_packed_qkv_{0};
|
||||
|
||||
int min_seq_len_for_efficient_attention_fp32_{0};
|
||||
|
||||
std::once_flag initialize_once_flag_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -52,20 +52,13 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
|
|||
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
disable_flash_attention_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
|
||||
#else
|
||||
disable_flash_attention_ = true;
|
||||
#endif
|
||||
kernel_options_ = this->GetAttentionKernelOptions();
|
||||
|
||||
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
// Memory efficient attention only supports float and float16, not bfloat16.
|
||||
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
disable_memory_efficient_attention_ = std::is_same<T, BFloat16>::value || !kernel_options_->UseEfficientAttention();
|
||||
|
||||
if (!disable_flash_attention_) {
|
||||
zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
|
||||
}
|
||||
|
|
@ -161,7 +154,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
!use_flash_attention &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
local_window_size_ == -1 &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
|
||||
if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
@ -201,6 +194,17 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
auto unpacked_qkv_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
|
||||
#endif
|
||||
|
||||
if (kernel_options_->AllowDebugInfo()) {
|
||||
AttentionKernelDebugInfo debug_info;
|
||||
debug_info.use_flash_attention = use_flash_attention;
|
||||
debug_info.use_efficient_attention = use_memory_efficient_attention;
|
||||
|
||||
debug_info.Print("GroupQueryAttention",
|
||||
this->Node().Name(),
|
||||
std::is_same<T, MLFloat16>::value,
|
||||
std::is_same<T, BFloat16>::value);
|
||||
}
|
||||
|
||||
// seqlens_k buffer
|
||||
size_t seqlens_k_bytes = 0;
|
||||
seqlens_k_bytes = sizeof(int) * parameters.batch_size;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include <memory>
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -32,6 +33,7 @@ class GroupQueryAttention final : public CudaKernel {
|
|||
bool disable_memory_efficient_attention_;
|
||||
static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
|
||||
IAllocatorUniquePtr<int> zeros_;
|
||||
const AttentionKernelOptions* kernel_options_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/multihead_attention.h"
|
||||
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
|
||||
|
|
@ -47,31 +46,16 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
|
|||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
|
||||
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");
|
||||
|
||||
disable_fused_self_attention_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
|
||||
kernel_options_ = this->GetAttentionKernelOptions();
|
||||
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
|
||||
disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
disable_flash_attention_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
|
||||
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
|
||||
attention::kMinSeqLenForFlashAttentionPackedQKV,
|
||||
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
|
||||
#else
|
||||
disable_flash_attention_ = true;
|
||||
min_seq_len_for_flash_attention_packed_qkv_ = 0;
|
||||
#endif
|
||||
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
|
||||
|
||||
disable_fused_cross_attention_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
|
||||
disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention();
|
||||
|
||||
// Allocate cache buffers
|
||||
constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast<size_t>(kCumulatedSequenceLengthCacheMaxBatchSize) + 1);
|
||||
|
|
@ -155,7 +139,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.num_heads);
|
||||
// When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512.
|
||||
if (use_flash_attention && key == nullptr && value == nullptr &&
|
||||
parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
|
||||
parameters.sequence_length < kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
|
||||
use_flash_attention = false;
|
||||
}
|
||||
// Allocate buffers
|
||||
|
|
@ -229,9 +213,10 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
int length_threshold = this->kernel_options_->MinSeqLenForEfficientAttentionFp32();
|
||||
bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16
|
||||
parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 ||
|
||||
parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32;
|
||||
parameters.sequence_length >= length_threshold ||
|
||||
parameters.kv_sequence_length >= length_threshold;
|
||||
|
||||
bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
|
||||
|
||||
|
|
@ -249,6 +234,21 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
constexpr bool use_memory_efficient_attention = false;
|
||||
#endif
|
||||
|
||||
if (kernel_options_->AllowDebugInfo()) {
|
||||
AttentionKernelDebugInfo debug_info;
|
||||
debug_info.use_flash_attention = use_flash_attention;
|
||||
debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr;
|
||||
debug_info.use_efficient_attention = use_memory_efficient_attention;
|
||||
if (fused_fp16_runner_ != nullptr) {
|
||||
debug_info.SetTrtFusedKernel(is_unidirectional_, enable_trt_flash_attention_, sequence_length);
|
||||
}
|
||||
|
||||
debug_info.Print("MultiHeadAttention",
|
||||
this->Node().Name(),
|
||||
std::is_same<T, MLFloat16>::value,
|
||||
std::is_same<T, BFloat16>::value);
|
||||
}
|
||||
|
||||
// When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace.
|
||||
// TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime.
|
||||
bool no_qkv_workspace = nullptr == value &&
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -31,12 +32,12 @@ class MultiHeadAttention final : public CudaKernel {
|
|||
bool disable_fused_cross_attention_;
|
||||
bool disable_flash_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
int min_seq_len_for_flash_attention_packed_qkv_;
|
||||
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
|
||||
mutable std::once_flag fused_fp16_runner_created_;
|
||||
mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
|
||||
mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_;
|
||||
mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_;
|
||||
const AttentionKernelOptions* kernel_options_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -33,12 +33,11 @@ REGISTER_KERNEL_TYPED(float)
|
|||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
template <typename T>
|
||||
TrtFusedAttention<T>::TrtFusedAttention() {
|
||||
disable_fused_runner_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
|
||||
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
|
||||
TrtFusedAttention<T>::TrtFusedAttention(const OpKernelInfo& info)
|
||||
: CudaKernel(info) {
|
||||
kernel_options_ = this->GetAttentionKernelOptions();
|
||||
disable_fused_runner_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -86,7 +85,8 @@ template class TrtFusedAttention<float>;
|
|||
template class TrtFusedAttention<MLFloat16>;
|
||||
|
||||
template <typename T>
|
||||
PackedAttention<T>::PackedAttention(const OpKernelInfo& info) : TrtFusedAttention<T>(), CudaKernel(info) {
|
||||
PackedAttention<T>::PackedAttention(const OpKernelInfo& info)
|
||||
: TrtFusedAttention<T>(info) {
|
||||
int64_t num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
num_heads_ = static_cast<int32_t>(num_heads);
|
||||
|
|
@ -268,7 +268,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* relative_position_bias = context->Input<Tensor>(5);
|
||||
|
||||
PackedAttentionParameters parameters;
|
||||
parameters.use_tf32 = UseTF32();
|
||||
parameters.use_tf32 = this->UseTF32();
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
|
||||
weights->Shape(),
|
||||
bias->Shape(),
|
||||
|
|
@ -295,6 +295,19 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
#endif
|
||||
|
||||
if (this->kernel_options_->AllowDebugInfo()) {
|
||||
AttentionKernelDebugInfo debug_info;
|
||||
debug_info.use_efficient_attention = use_memory_efficient_attention;
|
||||
if (fused_runner != nullptr) {
|
||||
debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length);
|
||||
}
|
||||
|
||||
debug_info.Print("PackedAttention",
|
||||
this->Node().Name(),
|
||||
std::is_same<T, MLFloat16>::value,
|
||||
std::is_same<T, BFloat16>::value);
|
||||
}
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
CudaT one = ToCudaType<T>::FromFloat(1.0f);
|
||||
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
|
||||
|
|
@ -313,7 +326,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
|
||||
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
|
||||
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
|
||||
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));
|
||||
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, this->UseTF32()));
|
||||
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
constexpr bool no_qkv_workspace = false; // need workspace to add bias
|
||||
|
|
@ -341,7 +354,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.fused_runner = reinterpret_cast<void*>(fused_runner);
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
|
||||
return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
|
||||
return QkvToContext<CudaT>(device_prop, cublas, this->Stream(context), parameters, data);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -17,14 +18,16 @@ namespace cuda {
|
|||
using namespace onnxruntime::cuda;
|
||||
|
||||
template <typename T>
|
||||
class TrtFusedAttention {
|
||||
class TrtFusedAttention : public CudaKernel {
|
||||
public:
|
||||
TrtFusedAttention();
|
||||
TrtFusedAttention(const OpKernelInfo& info);
|
||||
|
||||
protected:
|
||||
MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const;
|
||||
|
||||
protected:
|
||||
const AttentionKernelOptions* kernel_options_;
|
||||
|
||||
bool disable_fused_runner_;
|
||||
bool enable_trt_flash_attention_;
|
||||
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
|
||||
|
|
@ -32,7 +35,7 @@ class TrtFusedAttention {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
class PackedAttention final : public TrtFusedAttention<T>, public CudaKernel {
|
||||
class PackedAttention final : public TrtFusedAttention<T> {
|
||||
public:
|
||||
PackedAttention(const OpKernelInfo& info);
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
|
|||
|
|
@ -35,30 +35,16 @@ REGISTER_KERNEL_TYPED(MLFloat16)
|
|||
|
||||
template <typename T>
|
||||
PackedMultiHeadAttention<T>::PackedMultiHeadAttention(const OpKernelInfo& info)
|
||||
: TrtFusedAttention<T>(), CudaKernel(info) {
|
||||
: TrtFusedAttention<T>(info) {
|
||||
int64_t num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
num_heads_ = static_cast<int32_t>(num_heads);
|
||||
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
disable_flash_attention_ = sizeof(T) != 2 || onnxruntime::ParseEnvironmentVariableWithDefault<bool>(
|
||||
attention::kDisableFlashAttention, false);
|
||||
min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
|
||||
attention::kMinSeqLenForFlashAttentionPackedQKV,
|
||||
attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
|
||||
#else
|
||||
disable_flash_attention_ = true;
|
||||
min_seq_len_for_flash_attention_packed_qkv_ = 0;
|
||||
#endif
|
||||
disable_flash_attention_ = sizeof(T) != 2 || !this->kernel_options_->UseFlashAttention();
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
disable_memory_efficient_attention_ = onnxruntime::ParseEnvironmentVariableWithDefault<bool>(
|
||||
attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
disable_memory_efficient_attention_ = !this->kernel_options_->UseEfficientAttention();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -228,7 +214,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
|
|||
const Tensor* relative_position_bias = context->Input<Tensor>(6);
|
||||
|
||||
PackedAttentionParameters parameters;
|
||||
parameters.use_tf32 = UseTF32();
|
||||
parameters.use_tf32 = this->UseTF32();
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(),
|
||||
key,
|
||||
value,
|
||||
|
|
@ -255,7 +241,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
|
|||
|
||||
// When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
|
||||
if (use_flash_attention && key == nullptr && value == nullptr &&
|
||||
parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
|
||||
parameters.sequence_length < this->kernel_options_->MinSeqLenForFlashAttentionPackedQkv()) {
|
||||
use_flash_attention = false;
|
||||
}
|
||||
}
|
||||
|
|
@ -271,11 +257,25 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
|
|||
bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0;
|
||||
use_memory_efficient_attention =
|
||||
is_good_for_rpb &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
|
||||
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (this->kernel_options_->AllowDebugInfo()) {
|
||||
AttentionKernelDebugInfo debug_info;
|
||||
debug_info.use_flash_attention = use_flash_attention;
|
||||
debug_info.use_efficient_attention = use_memory_efficient_attention;
|
||||
if (fused_runner != nullptr) {
|
||||
debug_info.SetTrtFusedKernel(false /*causal*/, this->enable_trt_flash_attention_, parameters.sequence_length);
|
||||
}
|
||||
|
||||
debug_info.Print("PackedMultiHeadAttention",
|
||||
this->Node().Name(),
|
||||
std::is_same<T, MLFloat16>::value,
|
||||
std::is_same<T, BFloat16>::value);
|
||||
}
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
cublasHandle_t cublas = this->GetCublasHandle(context);
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@
|
|||
#pragma once
|
||||
|
||||
#include "contrib_ops/cuda/bert/packed_attention.h"
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
class PackedMultiHeadAttention final : public TrtFusedAttention<T>, public CudaKernel {
|
||||
class PackedMultiHeadAttention final : public TrtFusedAttention<T> {
|
||||
public:
|
||||
PackedMultiHeadAttention(const OpKernelInfo& info);
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
@ -32,7 +33,6 @@ class PackedMultiHeadAttention final : public TrtFusedAttention<T>, public CudaK
|
|||
|
||||
bool disable_memory_efficient_attention_;
|
||||
bool disable_flash_attention_;
|
||||
int min_seq_len_for_flash_attention_packed_qkv_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@
|
|||
#include "core/providers/cuda/shared_inc/cuda_call.h"
|
||||
#include "core/providers/cuda/tunable/cuda_tuning_context.h"
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void RunOnUnload(std::function<void()> function);
|
||||
|
|
@ -80,6 +84,14 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
bool IsNHWCPreferred() const { return info_.prefer_nhwc; }
|
||||
bool UseTF32() const { return info_.use_tf32; }
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// Attention kernel options parsed from sdpa_kernel cuda provider option.
|
||||
const AttentionKernelOptions* GetAttentionKernelOptions() const {
|
||||
attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true);
|
||||
return &attention_kernel_options_;
|
||||
}
|
||||
#endif
|
||||
|
||||
ProviderOptions GetProviderOptions() const override {
|
||||
return CUDAExecutionProviderInfo::ToProviderOptions(info_);
|
||||
}
|
||||
|
|
@ -110,6 +122,11 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
// the tuning context might be altered when calling into a TunableOp
|
||||
mutable cuda::tunable::CudaTuningContext tuning_context_;
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// Attention kernel options parsed from sdpa_kernel cuda provider option.
|
||||
mutable AttentionKernelOptions attention_kernel_options_;
|
||||
#endif
|
||||
|
||||
class PerThreadContext final {
|
||||
public:
|
||||
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_s
|
|||
constexpr const char* kPreferNHWCMode = "prefer_nhwc";
|
||||
constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream";
|
||||
constexpr const char* kUseTF32 = "use_tf32";
|
||||
constexpr const char* kSdpaKernel = "sdpa_kernel";
|
||||
|
||||
} // namespace provider_option_names
|
||||
} // namespace cuda
|
||||
|
|
@ -117,6 +118,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
|
|||
.AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc)
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream)
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32)
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kSdpaKernel, info.sdpa_kernel)
|
||||
.AddValueParser(
|
||||
cuda::provider_option_names::kTunableOpEnable,
|
||||
[&info](const std::string& value_str) -> Status {
|
||||
|
|
@ -170,6 +172,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
|
|||
{cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)},
|
||||
{cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)},
|
||||
{cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)},
|
||||
{cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)},
|
||||
};
|
||||
|
||||
return options;
|
||||
|
|
@ -192,6 +195,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid
|
|||
{cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)},
|
||||
{cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)},
|
||||
{cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)},
|
||||
{cuda::provider_option_names::kSdpaKernel, MakeStringWithClassicLocale(info.sdpa_kernel)},
|
||||
};
|
||||
|
||||
return options;
|
||||
|
|
|
|||
|
|
@ -79,6 +79,8 @@ struct CUDAExecutionProviderInfo {
|
|||
// By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices.
|
||||
bool use_tf32{true};
|
||||
|
||||
int sdpa_kernel{0};
|
||||
|
||||
static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
|
||||
static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info);
|
||||
static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info);
|
||||
|
|
@ -91,6 +93,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> {
|
|||
size_t value{0xbc9f1d34}; // seed
|
||||
|
||||
// Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each)
|
||||
// Do not exceed 32 bits here otherwise some bits will be lost in x86.
|
||||
size_t data = static_cast<size_t>(info.device_id) ^
|
||||
(static_cast<size_t>(info.arena_extend_strategy) << 16) ^
|
||||
(static_cast<size_t>(info.cudnn_conv_algo_search) << 18) ^
|
||||
|
|
@ -109,6 +112,7 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> {
|
|||
|
||||
onnxruntime::HashCombine(info.gpu_mem_limit, value);
|
||||
onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value);
|
||||
onnxruntime::HashCombine(info.sdpa_kernel, value);
|
||||
|
||||
// Memory pointers
|
||||
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.user_compute_stream), value);
|
||||
|
|
|
|||
|
|
@ -94,6 +94,12 @@ class CudaKernel : public OpKernel {
|
|||
return provider_->UseTF32();
|
||||
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
const AttentionKernelOptions* GetAttentionKernelOptions() const {
|
||||
return provider_->GetAttentionKernelOptions();
|
||||
}
|
||||
#endif
|
||||
|
||||
tunable::CudaTuningContext* GetTuningContext() const {
|
||||
return static_cast<tunable::CudaTuningContext*>(provider_->GetTuningContext());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -226,6 +226,7 @@ struct CUDA_Provider : Provider {
|
|||
info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0;
|
||||
info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0;
|
||||
info.use_tf32 = params->use_tf32 != 0;
|
||||
info.sdpa_kernel = params->sdpa_kernel;
|
||||
|
||||
return std::make_shared<CUDAProviderFactory>(info);
|
||||
}
|
||||
|
|
@ -260,6 +261,7 @@ struct CUDA_Provider : Provider {
|
|||
cuda_options.prefer_nhwc = internal_options.prefer_nhwc;
|
||||
cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream;
|
||||
cuda_options.use_tf32 = internal_options.use_tf32;
|
||||
cuda_options.sdpa_kernel = internal_options.sdpa_kernel;
|
||||
}
|
||||
|
||||
ProviderOptions GetProviderOptions(const void* provider_options) override {
|
||||
|
|
|
|||
|
|
@ -394,8 +394,8 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu
|
|||
}
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
if (data.sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32 ||
|
||||
data.kv_sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32) {
|
||||
if (data.sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32 ||
|
||||
data.kv_sequence_length >= contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32) {
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
|
||||
if (!SkipAttentionKernel(data, kernel_type)) {
|
||||
RunMultiHeadAttentionKernel(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,221 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
||||
#include "contrib_ops/cuda/bert/attention_kernel_options.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "test/util/include/scoped_env_vars.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
|
||||
using onnxruntime::AttentionKernelOptions;
|
||||
using onnxruntime::contrib::attention::AttentionBackend;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(AttentionKernelOptionsTest, NonZeroValue) {
|
||||
{
|
||||
AttentionKernelOptions options;
|
||||
int value = static_cast<int>(AttentionBackend::FLASH_ATTENTION) | static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION);
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_TRUE(options.UseFlashAttention());
|
||||
ASSERT_TRUE(options.UseEfficientAttention());
|
||||
ASSERT_FALSE(options.UseTrtFusedAttention());
|
||||
ASSERT_FALSE(options.UseCudnnFlashAttention());
|
||||
ASSERT_FALSE(options.UseUnfusedAttention());
|
||||
ASSERT_FALSE(options.UseTrtFlashAttention());
|
||||
ASSERT_FALSE(options.UseTrtCrossAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
AttentionKernelOptions options;
|
||||
int value = static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION) | static_cast<int>(AttentionBackend::MATH);
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_FALSE(options.UseFlashAttention());
|
||||
ASSERT_FALSE(options.UseEfficientAttention());
|
||||
ASSERT_TRUE(options.UseTrtFusedAttention());
|
||||
ASSERT_FALSE(options.UseCudnnFlashAttention());
|
||||
ASSERT_TRUE(options.UseUnfusedAttention());
|
||||
ASSERT_FALSE(options.UseTrtFlashAttention());
|
||||
ASSERT_FALSE(options.UseTrtCrossAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
AttentionKernelOptions options;
|
||||
int value = static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION);
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_FALSE(options.UseFlashAttention());
|
||||
ASSERT_FALSE(options.UseEfficientAttention());
|
||||
ASSERT_FALSE(options.UseTrtFusedAttention());
|
||||
ASSERT_TRUE(options.UseCudnnFlashAttention());
|
||||
ASSERT_FALSE(options.UseUnfusedAttention());
|
||||
ASSERT_FALSE(options.UseTrtFlashAttention());
|
||||
ASSERT_FALSE(options.UseTrtCrossAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
AttentionKernelOptions options;
|
||||
int value = static_cast<int>(AttentionBackend::TRT_FLASH_ATTENTION);
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_FALSE(options.UseFlashAttention());
|
||||
ASSERT_FALSE(options.UseEfficientAttention());
|
||||
ASSERT_FALSE(options.UseTrtFusedAttention());
|
||||
ASSERT_FALSE(options.UseCudnnFlashAttention());
|
||||
ASSERT_FALSE(options.UseUnfusedAttention());
|
||||
ASSERT_TRUE(options.UseTrtFlashAttention());
|
||||
ASSERT_FALSE(options.UseTrtCrossAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
|
||||
}
|
||||
|
||||
{
|
||||
AttentionKernelOptions options;
|
||||
int value = static_cast<int>(AttentionBackend::TRT_CROSS_ATTENTION) | static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION);
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_FALSE(options.UseFlashAttention());
|
||||
ASSERT_FALSE(options.UseEfficientAttention());
|
||||
ASSERT_FALSE(options.UseTrtFusedAttention());
|
||||
ASSERT_FALSE(options.UseCudnnFlashAttention());
|
||||
ASSERT_FALSE(options.UseUnfusedAttention());
|
||||
ASSERT_FALSE(options.UseTrtFlashAttention());
|
||||
ASSERT_TRUE(options.UseTrtCrossAttention());
|
||||
ASSERT_TRUE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
|
||||
}
|
||||
|
||||
// Test environment variables are ignored when option value is non-zero
|
||||
// Test default min sequence lengths are zeros
|
||||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}}};
|
||||
AttentionKernelOptions options;
|
||||
int value = static_cast<int>(AttentionBackend::FLASH_ATTENTION);
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_TRUE(options.UseFlashAttention());
|
||||
ASSERT_FALSE(options.UseEfficientAttention());
|
||||
ASSERT_FALSE(options.UseTrtFusedAttention());
|
||||
ASSERT_FALSE(options.UseCudnnFlashAttention());
|
||||
ASSERT_FALSE(options.UseUnfusedAttention());
|
||||
ASSERT_FALSE(options.UseTrtFlashAttention());
|
||||
ASSERT_FALSE(options.UseTrtCrossAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 0);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 0);
|
||||
}
|
||||
|
||||
// Test min sequence lengths can be parsed from environment variables when option value is non-zero
|
||||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"},
|
||||
{onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}};
|
||||
AttentionKernelOptions options;
|
||||
int value = static_cast<int>(AttentionBackend::FLASH_ATTENTION);
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_TRUE(options.UseFlashAttention());
|
||||
ASSERT_FALSE(options.UseEfficientAttention());
|
||||
ASSERT_FALSE(options.UseTrtFusedAttention());
|
||||
ASSERT_FALSE(options.UseCudnnFlashAttention());
|
||||
ASSERT_FALSE(options.UseUnfusedAttention());
|
||||
ASSERT_FALSE(options.UseTrtFlashAttention());
|
||||
ASSERT_FALSE(options.UseTrtCrossAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256);
|
||||
}
|
||||
}
|
||||
|
||||
// Test all environment variables take effect when option value is 0.
|
||||
TEST(AttentionKernelOptionsTest, DefaultOptionWithEnvVar) {
|
||||
constexpr int value = 0;
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "128"},
|
||||
{onnxruntime::contrib::attention::kMinSeqLenForEfficientAttentionFp32, "256"}}};
|
||||
AttentionKernelOptions options;
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_TRUE(options.UseFlashAttention());
|
||||
ASSERT_TRUE(options.UseEfficientAttention());
|
||||
ASSERT_TRUE(options.UseTrtFusedAttention());
|
||||
ASSERT_TRUE(options.UseCudnnFlashAttention());
|
||||
ASSERT_TRUE(options.UseUnfusedAttention());
|
||||
ASSERT_TRUE(options.UseTrtFlashAttention());
|
||||
ASSERT_TRUE(options.UseTrtCrossAttention());
|
||||
ASSERT_TRUE(options.UseTrtCausalAttention());
|
||||
ASSERT_TRUE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(), 128);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(), 256);
|
||||
}
|
||||
|
||||
// Test default min sequence lengths when environment variables are not set.
|
||||
TEST(AttentionKernelOptionsTest, DefaultMinSeqLens) {
|
||||
constexpr int value = 0;
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}}};
|
||||
AttentionKernelOptions options;
|
||||
options.InitializeOnce(value, false);
|
||||
ASSERT_FALSE(options.UseFlashAttention());
|
||||
ASSERT_FALSE(options.UseEfficientAttention());
|
||||
ASSERT_FALSE(options.UseTrtFusedAttention());
|
||||
ASSERT_FALSE(options.UseCudnnFlashAttention());
|
||||
ASSERT_TRUE(options.UseUnfusedAttention());
|
||||
ASSERT_FALSE(options.UseTrtFlashAttention());
|
||||
ASSERT_FALSE(options.UseTrtCrossAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
ASSERT_FALSE(options.UseTrtCausalAttention());
|
||||
EXPECT_EQ(options.MinSeqLenForFlashAttentionPackedQkv(),
|
||||
onnxruntime::contrib::attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
|
||||
EXPECT_EQ(options.MinSeqLenForEfficientAttentionFp32(),
|
||||
onnxruntime::contrib::attention::kDefaultMinSeqLenForEfficientAttentionFp32);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif
|
||||
|
|
@ -446,6 +446,8 @@ class TestInferenceSession(unittest.TestCase):
|
|||
|
||||
test_get_and_set_option_with_values("use_tf32", ["1", "0"])
|
||||
|
||||
test_get_and_set_option_with_values("sdpa_kernel", ["0", "1", "2"])
|
||||
|
||||
option["gpu_external_alloc"] = "0"
|
||||
option["gpu_external_free"] = "0"
|
||||
option["gpu_external_empty_cache"] = "0"
|
||||
|
|
|
|||
Loading…
Reference in a new issue