Upgrade CUTLASS to v2.11 and add sequence length threshold for cutlass FMHA (#14401)

### Description
Add sequence length threshold for triggering cutlass FMHA in FP32. See
performance test results in
https://github.com/microsoft/onnxruntime/pull/14343 to see how this
threshold is selected.

Upgrade cutlass to v2.11 and update deps.txt and cgmanifest for nuget
pipeline build (test build:
https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=268574&view=results)
This commit is contained in:
Tianlei Wu 2023-01-25 09:43:48 -08:00 committed by GitHub
parent 7cc9aed314
commit 94b1791974
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 37 additions and 13 deletions

View file

@ -387,6 +387,16 @@
},
"comments": "tensorboard"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "66d9cddc832c1cdc2b30a8755274f7f74640cfe6",
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
},
"comments": "cutlass"
}
}
]
}

View file

@ -34,3 +34,4 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/5916273f79a21551890fd
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374
safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef2172401cb4529eac9c0.zip;913a4046e5274d329af2806cb53194f617d8c0ab
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v2.11.0.zip;be70c559f07251ba7f33c789dba98872b444c10f

View file

@ -1,8 +1,9 @@
if (onnxruntime_USE_FLASH_ATTENTION)
include(FetchContent)
FetchContent_Declare(cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6
FetchContent_Declare(
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
)
FetchContent_GetProperties(cutlass)

View file

@ -68,6 +68,9 @@ constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTI
// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";
// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSequenceLengthForMemoryEfficientAttentionFp32 = 256;
} // namespace attention
} // namespace contrib

View file

@ -73,6 +73,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
&parameters,
device_prop.maxThreadsPerBlock,
past_seq_len));
assert(parameters.sequence_length == parameters.kv_sequence_length); // self attention
int batch_size = parameters.batch_size;
int sequence_length = parameters.sequence_length;
@ -107,7 +108,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == extra_add_qk &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, true);
if (use_causal_fused_runner) {
@ -127,7 +127,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == present &&
nullptr == extra_add_qk &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
@ -153,6 +152,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == past &&
nullptr == present &&
nullptr == extra_add_qk &&
(sizeof(T) == 2 || // sequence length threshold is 0 in FP16
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else
constexpr bool use_memory_efficient_attention = false;

View file

@ -132,9 +132,14 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
#if USE_FLASH_ATTENTION
bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32;
bool use_memory_efficient_attention = fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
nullptr == key_padding_mask && // TODO: support 1D mask
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else

View file

@ -221,12 +221,15 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
}
#if USE_FLASH_ATTENTION
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
if (data.sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
data.kv_sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) {
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
}
#endif

View file

@ -11,7 +11,7 @@ steps:
packageType: upack
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
version: 1.0.18
version: 1.0.24
downloadPath: $(Build.BinariesDirectory)/deps
# The private ADO project
@ -22,7 +22,7 @@ steps:
packageType: upack
feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
version: 1.0.18
version: 1.0.24
downloadPath: $(Build.BinariesDirectory)/deps
# You can add more ADO accounts at here.