mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Upgrade CUTLASS to v2.11 and add sequence length threshold for cutlass FMHA (#14401)
### Description

Add a sequence length threshold for triggering cutlass FMHA in FP32. See the performance test results in https://github.com/microsoft/onnxruntime/pull/14343 for how this threshold was selected.

Upgrade cutlass to v2.11 and update deps.txt and cgmanifest for the nuget pipeline build (test build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=268574&view=results).
This commit is contained in:
parent
7cc9aed314
commit
94b1791974
8 changed files with 37 additions and 13 deletions
|
|
@ -387,6 +387,16 @@
|
|||
},
|
||||
"comments": "tensorboard"
|
||||
}
|
||||
},
|
||||
{
|
||||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "66d9cddc832c1cdc2b30a8755274f7f74640cfe6",
|
||||
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
|
||||
},
|
||||
"comments": "cutlass"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,3 +34,4 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/5916273f79a21551890fd
|
|||
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374
|
||||
safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef2172401cb4529eac9c0.zip;913a4046e5274d329af2806cb53194f617d8c0ab
|
||||
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
|
||||
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v2.11.0.zip;be70c559f07251ba7f33c789dba98872b444c10f
|
||||
|
|
|
|||
7
cmake/external/cutlass.cmake
vendored
7
cmake/external/cutlass.cmake
vendored
|
|
@ -1,8 +1,9 @@
|
|||
if (onnxruntime_USE_FLASH_ATTENTION)
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
GIT_TAG 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
URL ${DEP_URL_cutlass}
|
||||
URL_HASH SHA1=${DEP_SHA1_cutlass}
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(cutlass)
|
||||
|
|
|
|||
|
|
@ -68,6 +68,9 @@ constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTI
|
|||
// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
|
||||
constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";
|
||||
|
||||
// Minimum sequence length to enable memory efficient attention in FP32.
|
||||
constexpr int kMinSequenceLengthForMemoryEfficientAttentionFp32 = 256;
|
||||
|
||||
} // namespace attention
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
&parameters,
|
||||
device_prop.maxThreadsPerBlock,
|
||||
past_seq_len));
|
||||
assert(parameters.sequence_length == parameters.kv_sequence_length); // self attention
|
||||
|
||||
int batch_size = parameters.batch_size;
|
||||
int sequence_length = parameters.sequence_length;
|
||||
|
|
@ -107,7 +108,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
nullptr == extra_add_qk &&
|
||||
parameters.past_sequence_length == 0 &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
parameters.sequence_length == parameters.kv_sequence_length &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, true);
|
||||
if (use_causal_fused_runner) {
|
||||
|
|
@ -127,7 +127,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
nullptr == present &&
|
||||
nullptr == extra_add_qk &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
parameters.sequence_length == parameters.kv_sequence_length &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_trt_flash_attention_, false);
|
||||
|
||||
|
|
@ -153,6 +152,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
nullptr == past &&
|
||||
nullptr == present &&
|
||||
nullptr == extra_add_qk &&
|
||||
(sizeof(T) == 2 || // sequence length threshold is 0 in FP16
|
||||
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
#else
|
||||
constexpr bool use_memory_efficient_attention = false;
|
||||
|
|
|
|||
|
|
@ -132,9 +132,14 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16
|
||||
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
|
||||
parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32;
|
||||
|
||||
bool use_memory_efficient_attention = fused_runner == nullptr &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
is_long_sequence &&
|
||||
nullptr == key_padding_mask && // TODO: support 1D mask
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -221,12 +221,15 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
|
|||
}
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
|
||||
if (!SkipAttentionKernel(data, kernel_type)) {
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
|
||||
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
|
||||
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
|
||||
if (data.sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
|
||||
data.kv_sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) {
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
|
||||
if (!SkipAttentionKernel(data, kernel_type)) {
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
|
||||
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
|
||||
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ steps:
|
|||
packageType: upack
|
||||
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
|
||||
definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
|
||||
version: 1.0.18
|
||||
version: 1.0.24
|
||||
downloadPath: $(Build.BinariesDirectory)/deps
|
||||
|
||||
# The private ADO project
|
||||
|
|
@ -22,7 +22,7 @@ steps:
|
|||
packageType: upack
|
||||
feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
|
||||
definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
|
||||
version: 1.0.18
|
||||
version: 1.0.24
|
||||
downloadPath: $(Build.BinariesDirectory)/deps
|
||||
|
||||
# You can add more ADO accounts here.
|
||||
|
|
|
|||
Loading…
Reference in a new issue