diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f848df01b4..738ec72998 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -387,6 +387,16 @@ }, "comments": "tensorboard" } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "66d9cddc832c1cdc2b30a8755274f7f74640cfe6", + "repositoryUrl": "https://github.com/NVIDIA/cutlass.git" + }, + "comments": "cutlass" + } } ] } diff --git a/cmake/deps.txt b/cmake/deps.txt index e35c704dc2..f6a9b4a2b2 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -34,3 +34,4 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/5916273f79a21551890fd re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374 safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef2172401cb4529eac9c0.zip;913a4046e5274d329af2806cb53194f617d8c0ab tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v2.11.0.zip;be70c559f07251ba7f33c789dba98872b444c10f diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index c4d52cfd2c..dc02168b86 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,8 +1,9 @@ if (onnxruntime_USE_FLASH_ATTENTION) include(FetchContent) - FetchContent_Declare(cutlass - GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6 + FetchContent_Declare( + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} ) FetchContent_GetProperties(cutlass) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index f45bbecfc7..d0137a2049 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ 
b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -68,6 +68,9 @@ constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTI // Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION"; +// Minimum sequence length to enable memory efficient attention in FP32. +constexpr int kMinSequenceLengthForMemoryEfficientAttentionFp32 = 256; + } // namespace attention } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3abacc935d..4a6d2dc137 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -73,6 +73,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { &parameters, device_prop.maxThreadsPerBlock, past_seq_len)); + assert(parameters.sequence_length == parameters.kv_sequence_length); // self attention int batch_size = parameters.batch_size; int sequence_length = parameters.sequence_length; @@ -107,7 +108,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == extra_add_qk && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && - parameters.sequence_length == parameters.kv_sequence_length && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, true); if (use_causal_fused_runner) { @@ -127,7 +127,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { nullptr == present && nullptr == extra_add_qk && parameters.hidden_size == parameters.v_hidden_size && - parameters.sequence_length == parameters.kv_sequence_length && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, enable_trt_flash_attention_, false); @@ -153,6 +152,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) 
const { nullptr == past && nullptr == present && nullptr == extra_add_qk && + (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 + parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 363a901858..c7e5d34e16 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -132,9 +132,14 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } #if USE_FLASH_ATTENTION + bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 + parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 || + parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32; + bool use_memory_efficient_attention = fused_runner == nullptr && fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && + is_long_sequence && nullptr == key_padding_mask && // TODO: support 1D mask has_memory_efficient_attention(sm, sizeof(T) == 2); #else diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 26d832c64a..eddf9ea4c7 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -221,12 +221,15 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { } #if USE_FLASH_ATTENTION - kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; - if (!SkipAttentionKernel(data, kernel_type)) { - RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, 
data.key_padding_mask_data, data.mask_type, - data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, - data.hidden_size, data.v_hidden_size, kernel_type, use_float16); + if (data.sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 || + data.kv_sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) { + kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; + if (!SkipAttentionKernel(data, kernel_type)) { + RunMultiHeadAttentionKernel( + data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, + data.hidden_size, data.v_hidden_size, kernel_type, use_float16); + } } #endif diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 99f4462950..0026fa8479 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.18 + version: 1.0.24 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.18 + version: 1.0.24 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here.