diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 953a1948d2..ce8b9bf4b1 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -392,7 +392,7 @@ "component": { "type": "git", "git": { - "commitHash": "66d9cddc832c1cdc2b30a8755274f7f74640cfe6", + "commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933", "repositoryUrl": "https://github.com/NVIDIA/cutlass.git" }, "comments": "cutlass" diff --git a/cmake/deps.txt b/cmake/deps.txt index 9d91eb97d2..a97ce38c7e 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -34,7 +34,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/5916273f79a21551890fd re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374 safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef2172401cb4529eac9c0.zip;913a4046e5274d329af2806cb53194f617d8c0ab tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 -cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v2.11.0.zip;be70c559f07251ba7f33c789dba98872b444c10f +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd # below are deps introduced by triton client, might remove after 1.14 release openssl;https://github.com/openssl/openssl/archive/refs/tags/openssl-3.0.7.zip;dda8fc81308555410505eb4a9eab3e1da0436a1d rapidjson;https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.zip;0fe7b4f7b83df4b3d517f4a202f3a383af7a0818 diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index dc02168b86..18ac668bb1 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -4,6 +4,7 @@ if (onnxruntime_USE_FLASH_ATTENTION) cutlass URL ${DEP_URL_cutlass} URL_HASH SHA1=${DEP_SHA1_cutlass} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace 
-p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch ) FetchContent_GetProperties(cutlass) diff --git a/cmake/patches/cutlass/cutlass.patch b/cmake/patches/cutlass/cutlass.patch new file mode 100644 index 0000000000..bda1de8b46 --- /dev/null +++ b/cmake/patches/cutlass/cutlass.patch @@ -0,0 +1,92 @@ +diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp +index 3790ebd3..cf727d09 100644 +--- a/include/cute/numeric/complex.hpp ++++ b/include/cute/numeric/complex.hpp +@@ -41,10 +41,14 @@ + // With CUDA 11.4, builds show spurious "-Wconversion" warnings + // on line 656 of thrust/detail/type_traits.h. + // These pragmas suppress the warnings. ++#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wconversion" ++#endif + #include ++#ifdef __GNUC__ + #pragma GCC diagnostic pop ++#endif + + #include + +diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h +index 59aec46a..8f2a913a 100644 +--- a/include/cutlass/functional.h ++++ b/include/cutlass/functional.h +@@ -89,7 +89,7 @@ struct multiplies { + } + }; + +-#if defined(__CUDA_ARCH__) ++#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + /// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set + template<> + struct plus<__half2> { +@@ -143,12 +143,12 @@ struct multiplies<__half> { + + + // Maximum with nan propogation +-// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN ++// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN + template + struct maximum_with_nan_propogation { + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { +- return lhs > rhs or std::isnan(lhs) ? lhs : rhs; ++ return lhs > rhs or isnan(lhs) ? 
lhs : rhs; + } + }; + +@@ -160,7 +160,7 @@ struct maximum_with_nan_propogation { + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); + #else +- res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; ++ res = lhs > rhs or isnan(lhs) ? lhs : rhs; + #endif + return res; + } +@@ -233,7 +233,7 @@ struct negate { + } + }; + +-/// Greater equal ++/// Greater equal + template + struct greater_equal { + CUTLASS_HOST_DEVICE +@@ -242,7 +242,7 @@ struct greater_equal { + } + }; + +-/// Greater ++/// Greater + template + struct greater { + CUTLASS_HOST_DEVICE +@@ -251,7 +251,7 @@ struct greater { + } + }; + +-/// Less equal ++/// Less equal + template + struct less_equal { + CUTLASS_HOST_DEVICE +@@ -260,7 +260,7 @@ struct less_equal { + } + }; + +-/// Less ++/// Less + template + struct less { + CUTLASS_HOST_DEVICE diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md old mode 100644 new mode 100755 index 656f0e86d2..7bf1e3d0f6 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -155,7 +155,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection
mask_index (optional) : M
-
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)
+
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
relative_position_bias (optional) : T
@@ -2404,7 +2404,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-
Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)
+
Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 07f3b49b4e..00e843ffb9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -41,7 +41,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // For mask_index, the following shapes are supported: // NULL, (B, 1), (1, 1) - // (B), (2 * B), + // (B), (2 * B), (3 * B + 2) // (B, T) // (B, S, T) // (B, 1, M, M) @@ -274,11 +274,13 @@ Status AttentionBase::CheckMask(const Tensor* mask_index, int64_t total_sequence_length) const { const auto& mask_dims = mask_index->Shape().GetDims(); if (mask_dims.size() == 1) { - if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size) { + if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size && mask_dims[0] != 3 * batch_size + 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size"); + "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size or 3 * batch_size + 2"); } - mask_type = (mask_dims[0] == batch_size ? AttentionMaskType::MASK_1D_KEY_SEQ_LEN : AttentionMaskType::MASK_1D_END_START); + mask_type = (mask_dims[0] == batch_size ? + AttentionMaskType::MASK_1D_KEY_SEQ_LEN : + mask_dims[0] == 2 * batch_size ? 
AttentionMaskType::MASK_1D_END_START : AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START); } else if (mask_dims.size() == 2) { if (mask_dims[0] == batch_size && mask_dims[1] == total_sequence_length) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 680f875d23..1b52ff2a0f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -7,13 +7,16 @@ namespace onnxruntime { namespace contrib { enum AttentionMaskType { - MASK_NONE, // No mask - MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length - MASK_1D_END_START, // [2 * batch_size] with end positions and start positions - MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. - MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length] - MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length] - MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + MASK_NONE, // No mask + MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length + MASK_1D_END_START, // [2 * batch_size] with end positions and start positions + MASK_1D_KEY_SEQ_LEN_START, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. 
+ MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length] + MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length] + MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] MASK_UNKNOWN }; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 80e506b499..cc7dad81b4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -25,7 +25,7 @@ Status CheckInputs(const T* query, float mask_filter_value, float scale, int max_threads_per_block) { - // key_padding_mask (K/V) : (B) or (B, L) or None + // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, L) or None // relative_position_bias : (B, 1, S, L) // past_key : (B, N, S*, H) // past_value : (B, N, S*, H) @@ -188,8 +188,12 @@ Status CheckInputs(const T* query, if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; const auto& mask_dims = key_padding_mask->Shape().GetDims(); - if (mask_dims.size() == 1 && mask_dims[0] == static_cast(batch_size)) { - mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + if (mask_dims.size() == 1) { + if (mask_dims[0] == static_cast(batch_size)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + } else if (mask_dims[0] == static_cast(3 * batch_size + 2)) { + mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; + } } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 11c982e53a..def1508ca2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -102,6 +102,7 @@ Status 
Attention::ComputeInternal(OpKernelContext* context) const { // Check whether we can use fused kernel int sm = device_prop.major * 10 + device_prop.minor; bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT // GPT fused kernels requires left side padding. mask can be: @@ -151,12 +152,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } #if USE_FLASH_ATTENTION + bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; bool use_memory_efficient_attention = fused_runner == nullptr && !disable_memory_efficient_attention_ && - nullptr == mask_index && // TODO: support 1D mask + (nullptr == mask_index || is_mask_1d_key_seq_len_start) && nullptr == past && nullptr == present && - nullptr == relative_position_bias && + (nullptr == relative_position_bias || is_good_for_rpb) && (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 28daf5d4af..c7127feced 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -445,6 +445,14 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, num_heads, 
sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); + } + if (data.fused_cross_attention_kernel != nullptr) { assert(qk_head_size == v_head_size); @@ -735,11 +743,14 @@ Status QkvToContext( return Status::OK(); } + // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; + #if USE_FLASH_ATTENTION if (data.use_memory_efficient_attention) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.mask_index == nullptr); assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); const void* query = q; @@ -754,23 +765,26 @@ Status QkvToContext( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; - p.batch_size = data.mask_index == nullptr ? parameters.batch_size : 2 * parameters.batch_size; + p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.total_sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; - p.cu_seqlens_q = nullptr; - p.cu_seqlens_k = nullptr; + p.scale = scale; + p.seqlen_k_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index + batch_size)); + p.seqstart_k_ptr = nullptr == data.mask_index ? 
nullptr : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); p.query = query; p.key = key; p.value = value; + p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; + p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr; p.stream = stream; run_memory_efficient_attention(p); - DUMP_TENSOR("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); return Status::OK(); } @@ -789,9 +803,6 @@ Status QkvToContext( float one = 1.0f; float zero = 0.f; - // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; float alpha = use_raw_attention_mask ? one : scale; cublasSetStream(cublas, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 17f4665a80..ed38cabc46 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -10,7 +10,7 @@ #endif #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" -#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h" +#include "41_fused_multi_head_attention/kernel_forward.h" namespace onnxruntime { namespace contrib { @@ -24,8 +24,10 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.query_ptr = const_cast(reinterpret_cast(params.query)); p.key_ptr = const_cast(reinterpret_cast(params.key)); p.value_ptr = const_cast(reinterpret_cast(params.value)); - p.cu_seqlens_q_ptr = params.cu_seqlens_q; - p.cu_seqlens_k_ptr = params.cu_seqlens_k; + p.attn_bias_ptr = 
const_cast(reinterpret_cast(params.attn_bias)); + p.seqstart_q_ptr = params.seqstart_q_ptr; + p.seqstart_k_ptr = params.seqstart_k_ptr; + p.seqlen_k_ptr = params.seqlen_k_ptr; p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward p.output_ptr = reinterpret_cast(params.output); @@ -42,28 +44,32 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.head_dim = params.qk_head_size; p.head_dim_value = params.v_head_size; + p.scale = params.scale; + // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel p.num_queries = params.sequence_length; p.num_keys = params.kv_sequence_length; - p.causal = params.causal; + if (params.causal) { + p.custom_mask_type = Attention::CausalFromTopLeft; + } // Input format is BxSxNxH, output is BxSxNxH p.q_strideH = params.qk_head_size; p.k_strideH = params.qk_head_size; p.v_strideH = params.v_head_size; - p.o_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; p.q_strideM = params.num_heads * params.qk_head_size; p.k_strideM = params.num_heads * params.qk_head_size; p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.o_strideB = static_cast(params.num_heads) * params.v_head_size * params.sequence_length; - - p.causal = params.causal; + p.bias_strideB = params.is_attn_bias_batched ? 
static_cast(p.bias_strideH) * params.num_heads : 0; } constexpr auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h deleted file mode 100644 index 7885983f99..0000000000 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h +++ /dev/null @@ -1,947 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holdvr nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#if USE_FLASH_ATTENTION - -#include -#include - -#include "cutlass/bfloat16.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/vector.h" -#include "cutlass/numeric_types.h" - -#include "41_fused_multi_head_attention/attention_scaling_coefs_updater.h" -#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" -#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/kernel/default_gemm.h" -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass/gemm/threadblock/default_mma_core_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" -#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/platform/platform.h" -#include "cutlass/transform/threadblock/predicated_tile_iterator.h" -#include "41_fused_multi_head_attention/debug_utils.h" -#include "41_fused_multi_head_attention/epilogue_pipelined.h" -#include "41_fused_multi_head_attention/epilogue_rescale_output.h" -#include 
"41_fused_multi_head_attention/find_default_mma.h" -#include "41_fused_multi_head_attention/gemm_kernel_utils.h" -#include "41_fused_multi_head_attention/mma_from_smem.h" - -#include - -using namespace gemm_kernel_utils; - -namespace { -template -constexpr int getWarpsPerSm() { - return ( - Arch::kMinComputeCapability >= 80 && - !cutlass::platform::is_same::value - ? 16 - : 12); -} -} // namespace - -template < - // The datatype of Q/K/V - typename scalar_t_, - // Architecture we are targeting (eg `cutlass::arch::Sm80`) - typename ArchTag, - // If Q/K/V are correctly aligned in memory and we can run a fast kernel - bool isAligned_, - int kQueriesPerBlock, - int kKeysPerBlock, - bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock` - > -struct AttentionKernel { - using scalar_t = scalar_t_; - using accum_t = float; - using lse_scalar_t = float; - using output_t = scalar_t; - // Accumulator between 2 iterations - // Using `accum_t` improves perf on f16 at the cost of - // numerical errors - using output_accum_t = accum_t; - static constexpr bool kIsAligned = isAligned_; - static constexpr int32_t kAlignLSE = 32; // block size of backward - static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && - cutlass::sizeof_bits::value == 16; - static constexpr bool kKeepOutputInRF = kSingleValueIteration; - static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && - !cutlass::platform::is_same::value; - - static_assert(kQueriesPerBlock % 32 == 0, ""); - static_assert(kKeysPerBlock % 32 == 0, ""); - static constexpr int kNumWarpsPerBlock = - kQueriesPerBlock * kKeysPerBlock / (32 * 32); - static constexpr int kWarpSize = 32; - - // Launch bounds - static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; - static constexpr int kMinBlocksPerSm = - getWarpsPerSm() / kNumWarpsPerBlock; - - struct Params { - // Input tensors - scalar_t* query_ptr; // [num_queries, num_heads, head_dim] - scalar_t* key_ptr; // [num_keys, 
num_heads, head_dim] - scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] - int32_t* cu_seqlens_q_ptr = nullptr; - int32_t* cu_seqlens_k_ptr = nullptr; - - // Output tensors - output_t* output_ptr; // [num_queries, num_heads, head_dim_value] - output_accum_t* - output_accum_ptr; // [num_queries, num_heads, head_dim_value] - lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null - - // Dimensions/strides - int32_t head_dim; - int32_t head_dim_value; - int32_t num_queries; - int32_t num_keys; - - bool causal; - - int32_t q_strideM; - int32_t k_strideM; - int32_t v_strideM; - - // Everything below is only used in `advance_to_block` - // and shouldn't use registers - int32_t q_strideH; - int32_t k_strideH; - int32_t v_strideH; - int32_t o_strideH; - int64_t q_strideB; - int64_t k_strideB; - int64_t v_strideB; - int64_t o_strideB; - int32_t num_batches; - int32_t num_heads; - - // https://github.com/NVIDIA/cutlass/issues/771 - CUTLASS_HOST_DEVICE int32_t o_strideM() const { - return head_dim_value * num_heads; - } - - // Moves pointers to what we should process - // Returns "false" if there is no work to do - CUTLASS_DEVICE bool advance_to_block() { - auto batch_id = blockIdx.z; - auto head_id = blockIdx.y; - auto query_start = blockIdx.x * kQueriesPerBlock; - - auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; - - int64_t q_start, k_start; - // Advance to current batch - in case of different sequence lengths - if (cu_seqlens_q_ptr != nullptr) { - assert(cu_seqlens_k_ptr != nullptr); - cu_seqlens_q_ptr += batch_id; - cu_seqlens_k_ptr += batch_id; - q_start = cu_seqlens_q_ptr[0]; - k_start = cu_seqlens_k_ptr[0]; - int64_t q_next_start = cu_seqlens_q_ptr[1]; - int64_t k_next_start = cu_seqlens_k_ptr[1]; - num_queries = q_next_start - q_start; - num_keys = k_next_start - k_start; - - if (query_start >= num_queries) { - return false; - } - } else { - query_ptr += batch_id * q_strideB; - key_ptr += batch_id * k_strideB; - 
value_ptr += batch_id * v_strideB; - output_ptr += batch_id * o_strideB; - if (output_accum_ptr != nullptr) { - output_accum_ptr += batch_id * o_strideB; - } - q_start = 0; - k_start = 0; - } - - // Advance to the current batch / head / query_start - query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; - key_ptr += k_start * k_strideM + head_id * k_strideH; - value_ptr += k_start * v_strideM + head_id * v_strideH; - output_ptr += int64_t(q_start + query_start) * o_strideM() + - head_id * o_strideH; - - if (output_accum_ptr != nullptr) { - output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + - head_id * o_strideH; - } else { - // Accumulate directly in the destination buffer (eg for f32) - output_accum_ptr = (accum_t*)output_ptr; - } - if (logsumexp_ptr != nullptr) { - // lse[batch_id, head_id, query_start] - logsumexp_ptr += - batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; - } - - num_queries -= query_start; - if (causal) { - num_keys = cutlass::fast_min( - int32_t(query_start + kQueriesPerBlock), num_keys); - } - num_batches = 0; // no longer used after - - // Make sure the compiler knows these variables are the same on all - // the threads of the warp. - query_ptr = warp_uniform(query_ptr); - key_ptr = warp_uniform(key_ptr); - value_ptr = warp_uniform(value_ptr); - output_ptr = warp_uniform(output_ptr); - output_accum_ptr = warp_uniform(output_accum_ptr); - logsumexp_ptr = warp_uniform(logsumexp_ptr); - num_queries = warp_uniform(num_queries); - num_keys = warp_uniform(num_keys); - head_dim = warp_uniform(head_dim); - head_dim_value = warp_uniform(head_dim_value); - return true; - } - - __host__ dim3 getBlocksGrid() const { - return dim3( - ceil_div(num_queries, (int32_t)kQueriesPerBlock), - num_heads, - num_batches); - } - __host__ dim3 getThreadsGrid() const { - return dim3(kWarpSize, kNumWarpsPerBlock, 1); - } - }; - - struct MM0 { - /* - In this first matmul, we compute a block of `Q @ K.T`. 
- While the calculation result is still hot in registers, we update - `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value - into a shared-memory ("AccumulatorSharedStorage") that is used later as - operand A for the second matmul (see MM1) - */ - using GemmType = DefaultGemmType; - - using OpClass = typename GemmType::OpClass; - using DefaultConfig = - typename cutlass::gemm::device::DefaultGemmConfiguration< - OpClass, - ArchTag, - scalar_t, - scalar_t, - scalar_t, // ElementC - accum_t // ElementAccumulator - >; - static constexpr int kAlignmentA = - kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; - static constexpr int kAlignmentB = - kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; - using ThreadblockShape = cutlass::gemm:: - GemmShape; - using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; - using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< - scalar_t, // ElementA, - cutlass::layout::RowMajor, // LayoutA, - kAlignmentA, - scalar_t, // ElementB, - cutlass::layout::ColumnMajor, // LayoutB, - kAlignmentB, - accum_t, - cutlass::layout::RowMajor, // LayoutC, - OpClass, - ArchTag, // ArchTag - ThreadblockShape, // ThreadblockShape - WarpShape, // WarpShape - typename GemmType::InstructionShape, // InstructionShape - DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that - // uses too much smem - typename GemmType::Operator // Operator - >::DefaultMma; - using MmaCore = typename DefaultMma::MmaCore; - using IteratorA = typename DefaultMma::IteratorA; - using IteratorB = typename DefaultMma::IteratorB; - using Mma = typename DefaultMma::ThreadblockMma; - using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< - typename Mma::Operator::IteratorC, - accum_t, - kWarpSize>::Updater; - static_assert( - MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * - MmaCore::WarpCount::kK == - kNumWarpsPerBlock, - ""); - - // Epilogue to store to 
shared-memory in a format that we can use later for - // the second matmul - using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< - typename Mma::Operator::IteratorC, - typename Mma::Operator, - scalar_t, - WarpShape, - ThreadblockShape>; - using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; - }; - - struct MM1 { - /** - Second matmul: perform `attn @ V` where `attn` is the attention (not - normalized) and stored in shared memory - */ - using GemmType = DefaultGemmType; - - using OpClass = typename GemmType::OpClass; - using DefaultConfig = - typename cutlass::gemm::device::DefaultGemmConfiguration< - OpClass, - ArchTag, - scalar_t, - scalar_t, - output_accum_t, // ElementC - accum_t // ElementAccumulator - >; - static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem - static constexpr int kAlignmentB = - kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; - using ThreadblockShape = cutlass::gemm:: - GemmShape; - using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; - using InstructionShape = typename GemmType::InstructionShape; - - using LayoutB = cutlass::layout::RowMajor; - using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< - scalar_t, // ElementA, - cutlass::layout::RowMajor, // LayoutA, - kAlignmentA, - scalar_t, // ElementB, - LayoutB, // LayoutB, - kAlignmentB, - output_accum_t, - cutlass::layout::RowMajor, // LayoutC, - accum_t, - OpClass, - ArchTag, - ThreadblockShape, - WarpShape, - typename GemmType::InstructionShape, - typename DefaultConfig::EpilogueOutputOp, - void, // ThreadblockSwizzle - not used - DefaultConfig::kStages, - false, // SplitKSerial - typename GemmType::Operator>; - - using DefaultMmaFromSmem = - typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< - typename DefaultGemm::Mma, - typename MM0::AccumulatorSharedStorage>; - using Mma = typename DefaultMmaFromSmem::Mma; - using IteratorB = typename Mma::IteratorB; - using WarpCount = 
typename Mma::WarpCount; - static_assert( - WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, - ""); - - using DefaultEpilogue = typename DefaultGemm::Epilogue; - using OutputTileIterator = - typename cutlass::epilogue::threadblock::PredicatedTileIterator< - typename DefaultEpilogue::OutputTileIterator::ThreadMap, - output_t>; - using OutputTileIteratorAccum = - typename cutlass::epilogue::threadblock::PredicatedTileIterator< - typename DefaultEpilogue::OutputTileIterator::ThreadMap, - output_accum_t>; - - struct SharedStorageMM1 { - typename Mma::SharedStorage mm; - }; - }; - - static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; - static constexpr int64_t kAlignmentK = MM0::kAlignmentB; - static constexpr int64_t kAlignmentV = 1; - - // Shared storage - depends on kernel params - struct ScalingCoefs { - cutlass::Array m_prime; - cutlass::Array s_prime; - cutlass::Array mi; - }; - - struct SharedStorageEpilogueAtEnd : ScalingCoefs { - struct SharedStorageAfterMM0 { - // Everything here might be overwritten during MM0 - typename MM0::AccumulatorSharedStorage si; - typename MM1::SharedStorageMM1 mm1; - }; - - union { - typename MM0::Mma::SharedStorage mm0; - SharedStorageAfterMM0 after_mm0; - typename MM1::DefaultEpilogue::SharedStorage epilogue; - }; - - CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& - epilogue_shared_storage() { - return epilogue; - } - }; - - struct SharedStorageEpilogueInLoop : ScalingCoefs { - struct SharedStorageAfterMM0 { - // Everything here might be overwritten during MM0 - typename MM0::AccumulatorSharedStorage si; - typename MM1::SharedStorageMM1 mm1; - typename MM1::DefaultEpilogue::SharedStorage epilogue; - }; - - union { - typename MM0::Mma::SharedStorage mm0; - SharedStorageAfterMM0 after_mm0; - }; - - CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& - epilogue_shared_storage() { - return after_mm0.epilogue; - } - }; - - using SharedStorage = typename cutlass::platform::conditional< - 
kSingleValueIteration || kKeepOutputInRF, - SharedStorageEpilogueAtEnd, - SharedStorageEpilogueInLoop>::type; - - static bool __host__ check_supported(Params const& p) { - CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); - CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); - CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); - XFORMERS_CHECK( - p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned"); - XFORMERS_CHECK( - p.k_strideM % kAlignmentK == 0, "key is not correctly aligned"); - XFORMERS_CHECK( - p.v_strideM % kAlignmentV == 0, "value is not correctly aligned"); - XFORMERS_CHECK( - p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned"); - XFORMERS_CHECK( - p.k_strideH % kAlignmentK == 0, "key is not correctly aligned"); - XFORMERS_CHECK( - p.v_strideH % kAlignmentV == 0, "value is not correctly aligned"); - return true; - } - - static void CUTLASS_DEVICE attention_kernel(Params& p) { - // In this block, we will only ever: - // - read query[query_start:query_end, :] - // - write to output[query_start:query_end, :] - - extern __shared__ char smem_buffer[]; - SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); - auto& m_prime = shared_storage.m_prime; - auto& s_prime = shared_storage.s_prime; - auto& mi = shared_storage.mi; - - static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); - if (thread_id() < kQueriesPerBlock) { - s_prime[thread_id()] = accum_t(0); - m_prime[thread_id()] = - -cutlass::platform::numeric_limits::infinity(); - mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); - } - typename MM1::Mma::FragmentC accum_o; - accum_o.clear(); - - auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { - using OutputTileIterator = typename MM1::OutputTileIterator; - return OutputTileIterator( - typename OutputTileIterator::Params{(int32_t)p.o_strideM()}, - p.output_ptr, - typename OutputTileIterator::TensorCoord{ - p.num_queries, p.head_dim_value}, - thread_id(), - {0, col}); - }; - - auto 
createOutputAccumIter = [&](int col) -> - typename MM1::OutputTileIteratorAccum { - using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; - return OutputTileIteratorAccum( - typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()}, - p.output_accum_ptr, - typename OutputTileIteratorAccum::TensorCoord{ - p.num_queries, p.head_dim_value}, - thread_id(), - {0, col}); - }; - - // Iterate through keys - for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; - iter_key_start += kKeysPerBlock) { - int32_t problem_size_0_m = - cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); - int32_t problem_size_0_n = cutlass::fast_min( - int32_t(kKeysPerBlock), p.num_keys - iter_key_start); - int32_t const& problem_size_0_k = p.head_dim; - int32_t const& problem_size_1_n = p.head_dim_value; - int32_t const& problem_size_1_k = problem_size_0_n; - - auto prologueV = [&](int blockN) { - typename MM1::Mma::IteratorB iterator_V( - typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, - p.value_ptr + iter_key_start * p.v_strideM, - {problem_size_1_k, problem_size_1_n}, - thread_id(), - cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); - MM1::Mma::prologue( - shared_storage.after_mm0.mm1.mm, - iterator_V, - thread_id(), - problem_size_1_k); - }; - - __syncthreads(); // Need to have shared memory initialized, and `m_prime` - // updated from end of prev iter - // - // MATMUL: Q.K_t - // - // Computes the block-matrix product of: - // (a) query[query_start:query_end, :] - // with - // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] - // and stores that into `shared_storage.si` - // - - // Compute threadblock location - cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; - - cutlass::MatrixCoord tb_offset_A{ - tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; - - cutlass::MatrixCoord tb_offset_B{ - tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; - - // Construct iterators to A and B operands - typename 
MM0::IteratorA iterator_A( - typename MM0::IteratorA::Params( - typename MM0::MmaCore::LayoutA(p.q_strideM)), - p.query_ptr, - {problem_size_0_m, problem_size_0_k}, - thread_id(), - tb_offset_A); - - typename MM0::IteratorB iterator_B( - typename MM0::IteratorB::Params( - typename MM0::MmaCore::LayoutB(p.k_strideM)), - p.key_ptr + iter_key_start * p.k_strideM, - {problem_size_0_k, problem_size_0_n}, - thread_id(), - tb_offset_B); - - auto my_warp_id = warp_id(); - auto my_lane_id = lane_id(); - - // Construct thread-scoped matrix multiply - typename MM0::Mma mma( - shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); - - typename MM0::Mma::FragmentC accum; - - accum.clear(); - - auto gemm_k_iterations = - (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - __syncthreads(); - - if (kPreloadV) { - prologueV(0); - } - - typename MM0::Mma::Operator::IteratorC::TensorCoord - iteratorC_tile_offset = { - (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + - (my_warp_id % MM0::Mma::WarpCount::kM), - (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + - (my_warp_id / MM0::Mma::WarpCount::kM)}; - - // Mask out last if causal - if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { - auto query_start = blockIdx.x * kQueriesPerBlock; - auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( - lane_id(), warp_id(), iteratorC_tile_offset); - int32_t last_col; - MM0::ScalingCoefsUpdater::iterateRows( - lane_offset, - [&](int accum_m) { - last_col = query_start + accum_m - iter_key_start; - }, - [&](int accum_m, int accum_n, int idx) { - if (accum_n > last_col) { - accum[idx] = - -cutlass::platform::numeric_limits::infinity(); - } - }, - [&](int accum_m) {}); - } - DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { - DISPATCH_BOOL( - p.num_keys - iter_key_start >= kKeysPerBlock, - kFullColumns, - ([&] { - // Update `mi` from accum 
stored in registers - // Also updates `accum` with accum[i] <- - // exp(accum[i] * scale - // - mi) - MM0::ScalingCoefsUpdater::update< - kQueriesPerBlock, - kFullColumns, - kIsFirst, - kKeepOutputInRF>( - accum_o, - accum, - mi, - m_prime, - s_prime, - lane_id(), - thread_id(), - warp_id(), - p.num_keys - iter_key_start, - iteratorC_tile_offset, - 1.0f / cutlass::fast_sqrt(float(p.head_dim))); - })); - })); - - // Output results to shared-memory - int warp_idx_mn_0 = my_warp_id % - (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); - auto output_tile_coords = cutlass::MatrixCoord{ - warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, - warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; - - MM0::B2bGemm::accumToSmem( - shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); - - __syncthreads(); - - // - // MATMUL: Attn . V - // Run the matmul `attn @ V` for a block of attn and V. - // `attn` is read from shared memory (in `shared_storage_si`) - // `V` is read from global memory (with iterator_B) - // - - const int64_t nBlockN = kSingleValueIteration - ? 
1 - : ceil_div( - (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); - for (int blockN = 0; blockN < nBlockN; ++blockN) { - int gemm_k_iterations = - (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add and store it in accum - // (in registers) - if (!kPreloadV) { - __syncthreads(); // we share shmem between mma and epilogue - } - - typename MM1::Mma::IteratorB iterator_V( - typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, - p.value_ptr + iter_key_start * p.v_strideM, - {problem_size_1_k, problem_size_1_n}, - thread_id(), - cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); - typename MM1::Mma mma_pv( - shared_storage.after_mm0.mm1.mm, - shared_storage.after_mm0.si, - (int)thread_id(), - (int)warp_id(), - (int)lane_id(), - (int)problem_size_1_k); - mma_pv.set_prologue_done(kPreloadV); - if (!kKeepOutputInRF) { - accum_o.clear(); - } - mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); - __syncthreads(); - - if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { - prologueV(blockN + 1); - } - - if (!kKeepOutputInRF) { - DISPATCH_BOOL( - iter_key_start == 0, kIsFirst, ([&] { - DISPATCH_BOOL( - (iter_key_start + kKeysPerBlock) >= p.num_keys, - kIsLast, - ([&] { - using DefaultEpilogue = typename MM1::DefaultEpilogue; - using DefaultOp = - typename MM1::DefaultConfig::EpilogueOutputOp; - using ElementCompute = typename DefaultOp::ElementCompute; - using EpilogueOutputOp = typename cutlass::epilogue:: - thread::MemoryEfficientAttentionNormalize< - typename cutlass::platform::conditional< - kIsLast, - output_t, - output_accum_t>::type, - output_accum_t, - DefaultOp::kCount, - typename DefaultOp::ElementAccumulator, - ElementCompute, - kIsFirst, - kIsLast, - cutlass::Array>; - using Epilogue = typename cutlass::epilogue::threadblock:: - EpiloguePipelined< - typename DefaultEpilogue::Shape, - typename MM1::Mma::Operator, - DefaultEpilogue::kPartitionsK, - 
typename cutlass::platform::conditional< - kIsLast, - typename MM1::OutputTileIterator, - typename MM1::OutputTileIteratorAccum>::type, - typename DefaultEpilogue:: - AccumulatorFragmentIterator, - typename DefaultEpilogue::WarpTileIterator, - typename DefaultEpilogue::SharedLoadIterator, - EpilogueOutputOp, - typename DefaultEpilogue::Padding, - DefaultEpilogue::kFragmentsPerIteration, - true, // IterationsUnroll - typename MM1::OutputTileIteratorAccum // Read - // iterator - >; - - int col = blockN * MM1::Mma::Shape::kN; - auto source_iter = createOutputAccumIter(col); - auto dest_iter = call_conditional< - kIsLast, - decltype(createOutputIter), - decltype(createOutputAccumIter)>:: - apply(createOutputIter, createOutputAccumIter, col); - EpilogueOutputOp rescale(s_prime, m_prime); - Epilogue epilogue( - shared_storage.epilogue_shared_storage(), - thread_id(), - warp_id(), - lane_id()); - epilogue(rescale, dest_iter, accum_o, source_iter); - })); - })); - if (!kSingleValueIteration) { - __syncthreads(); - } - } - } - __syncthreads(); // we modify `m_prime` after - } - - if (kKeepOutputInRF) { - constexpr bool kIsFirst = true; - constexpr bool kIsLast = true; - using DefaultEpilogue = typename MM1::DefaultEpilogue; - using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; - using ElementCompute = typename DefaultOp::ElementCompute; - using EpilogueOutputOp = - typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< - output_t, // output - output_accum_t, // source - DefaultOp::kCount, - typename DefaultOp::ElementAccumulator, // accum - output_accum_t, // compute - kIsFirst, - kIsLast, - cutlass::Array>; - using Epilogue = - typename cutlass::epilogue::threadblock::EpiloguePipelined< - typename DefaultEpilogue::Shape, - typename MM1::Mma::Operator, - DefaultEpilogue::kPartitionsK, - typename MM1::OutputTileIterator, // destination - typename DefaultEpilogue::AccumulatorFragmentIterator, - typename DefaultEpilogue::WarpTileIterator, - 
typename DefaultEpilogue::SharedLoadIterator, - EpilogueOutputOp, - typename DefaultEpilogue::Padding, - DefaultEpilogue::kFragmentsPerIteration, - true, // IterationsUnroll - typename MM1::OutputTileIteratorAccum // source tile - >; - auto dest_iter = createOutputIter(0); - EpilogueOutputOp rescale(s_prime, m_prime); - Epilogue epilogue( - shared_storage.epilogue_shared_storage(), - thread_id(), - warp_id(), - lane_id()); - epilogue(rescale, dest_iter, accum_o); - } - - // 7. Calculate logsumexp - // To make the backward easier, we pad logsumexp with `inf` - // this avoids a few bound checks, and is not more expensive during fwd - static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); - if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { - auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; - if (thread_id() < p.num_queries) { - p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) + - cutlass::fast_log(accum_t(s_prime[thread_id()])); - } else if (thread_id() < lse_dim) { - p.logsumexp_ptr[thread_id()] = - cutlass::platform::numeric_limits::infinity(); - } - } - } - - static CUTLASS_DEVICE int8_t lane_id() { - return threadIdx.x; - } - static CUTLASS_DEVICE int8_t warp_id() { - return threadIdx.y; - } - static CUTLASS_DEVICE int16_t thread_id() { - return threadIdx.x + threadIdx.y * blockDim.x; - } -}; - -template -__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) - attention_kernel_batched_impl(typename AK::Params p) { - if (!p.advance_to_block()) { - return; - } - AK::attention_kernel(p); -} - -template -__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) - attention_kernel_batched(typename AK::Params params); - -#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) 
\ - template <> \ - __global__ void __launch_bounds__( \ - __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ - attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \ - using Kernel = __VA_ARGS__; -#define _ATTENTION_KERNEL_FORWARD_END() } - -#ifdef __CUDA_ARCH__ -#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ -#else -#define __CUDA_ARCH_OR_ZERO__ 0 -#endif - -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \ - ARCH, \ - SCALAR_T, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER) \ - _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ - SCALAR_T, \ - cutlass::arch::Sm##ARCH, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER>) \ - if (!p.advance_to_block()) { \ - return; \ - } \ - Kernel::attention_kernel(p); \ - _ATTENTION_KERNEL_FORWARD_END(); - -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \ - ARCH, \ - SCALAR_T, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER) \ - _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ - SCALAR_T, \ - cutlass::arch::Sm##ARCH, \ - IS_ALIGNED, \ - QUERIES_PER_BLOCK, \ - KEYS_PER_BLOCK, \ - SINGLE_VALUE_ITER>) \ - printf( \ - "FATAL: this function is for sm%d, but was built for sm%d\n", \ - int(ARCH), \ - int(__CUDA_ARCH_OR_ZERO__)); \ - _ATTENTION_KERNEL_FORWARD_END(); - -// All kernels are disabled by default -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__) -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) 
\ - INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__) - -// Enable the right one based on __CUDA_ARCH__ -#ifndef __CUDA_ARCH__ -#elif __CUDA_ARCH__ < 500 -//#error "Need cuda arch at least 5.0" -#elif __CUDA_ARCH__ < 700 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__) -#elif __CUDA_ARCH__ < 750 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__) -#elif __CUDA_ARCH__ < 800 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__) -#elif __CUDA_ARCH__ >= 800 -#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 -#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ - INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__) -#endif - -#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index d4484628b6..3cd86674f1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -21,15 +21,21 @@ struct MemoryEfficientAttentionParams { int32_t qk_head_size; int32_t v_head_size; bool causal; + // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models. 
+ bool is_attn_bias_batched; - int32_t* cu_seqlens_q; - int32_t* cu_seqlens_k; + float scale; - const void* query; // [B, S, N, H] - const void* key; // [B, L, N, H], where L is kv_sequence_length - const void* value; // [B, L, N, H_v] - void* output; // [B, S, N, H_v] - void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise + int32_t* seqstart_q_ptr; + int32_t* seqstart_k_ptr; + int32_t* seqlen_k_ptr; + + const void* query; // [B, S, N, H] + const void* key; // [B, L, N, H], where L is kv_sequence_length + const void* value; // [B, L, N, H_v] + const void* attn_bias; // [N, S, S*] or null + void* output; // [B, S, N, H_v] + void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise cudaStream_t stream; static bool need_workspace(size_t v_head_size, bool is_float) { diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ac3c7afcb1..f077d56f03 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -116,6 +116,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { int sm = device_prop.major * 10 + device_prop.minor; bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && @@ -168,12 +169,14 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 || parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32; + bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; + bool 
use_memory_efficient_attention = fused_runner == nullptr && fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && is_long_sequence && - nullptr == key_padding_mask && // TODO: support 1D mask - nullptr == relative_position_bias && + (relative_position_bias == nullptr || is_good_for_rpb) && + (nullptr == key_padding_mask || is_mask_1d_key_seq_len_start) && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index b205b64954..4cc07d8764 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -272,7 +272,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "mask_index", "Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), " "(batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), " - "or index with shape (batch_size) or (2 * batch_size)", + "or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)", "M", OpSchema::Optional) .Input(4, @@ -590,7 +590,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(4, "key_padding_mask", - "Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)", + "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)", "M", OpSchema::Optional) .Input(5, diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc index 5c7f9dfab0..c7f7c7b653 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc @@ -3090,6 +3090,195 @@ void GetSelfAttentionDataWithPast(AttentionTestData& data) { data.is_static_kv = false; } +void GetAttentionDataCutlassRelPosBias(AttentionTestData& data) { + data.hidden_size = 8; + 
data.v_hidden_size = 8; + data.num_heads = 2; + data.batch_size = 1; + data.sequence_length = 8; + data.kv_sequence_length = 0; + data.mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; + + data.key_padding_mask_data = {8, 0, 8, 0, 8}; + + data.skip_kernel_types = { + AttentionKernelType::AttentionKernel_TrtFlashAttention, + AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, + AttentionKernelType::AttentionKernel_TrtFusedAttention}; + + { + data.query_data = { + -0.029273793f, 0.079709493f, 0.064531095f, 0.24270254f, + -0.28326464f, 0.20984903f, -0.10173888f, 0.18373983f, + + 0.089472905f, -0.0063416883f, -0.049477674f, 0.36512995f, + -0.23620239f, 0.1464397f, 0.068258412f, 0.31627196f, + + 0.12436871f, -0.0075563118f, -0.11576633f, 0.41008925f, + -0.19456652f, 0.20145792f, 0.11790096f, 0.39789933f, + + 0.002485469f, 0.029660821f, -0.043821491f, 0.3892332f, + -0.26994205f, 0.14530671f, 0.12950704f, 0.36185294f, + + -0.029273793f, 0.079709493f, 0.064531095f, 0.24270254f, + -0.28326464f, 0.20984903f, -0.10173888f, 0.18373983f, + + 0.089472905f, -0.0063416883f, -0.049477674f, 0.36512995f, + -0.23620239f, 0.1464397f, 0.068258412f, 0.31627196f, + + 0.12436871f, -0.0075563118f, -0.11576633f, 0.41008925f, + -0.19456652f, 0.20145792f, 0.11790096f, 0.39789933f, + + 0.002485469f, 0.029660821f, -0.043821491f, 0.3892332f, + -0.26994205f, 0.14530671f, 0.12950704f, 0.36185294f, + }; + } + { + data.key_data = { + -0.32538497f, 0.34121913f, -0.18170178f, -0.015152611f, + 0.20429322f, 0.25979176f, 0.21269324f, 0.0025638193f, + + -0.24246037f, 0.21112341f, -0.36959589f, -0.16091451f, + 0.24183474f, 0.18856162f, 0.094487116f, -0.3053959f, + + -0.35736683f, 0.29276621f, -0.4217523f, -0.20031664f, + 0.33148992f, 0.26928401f, 0.19360018f, -0.39494509f, + + -0.28043351f, 0.24279942f, -0.29154932f, -0.13657911f, + 0.31932494f, 0.3500579f, 0.027172565f, -0.19327414f, + + -0.32538497f, 0.34121913f, -0.18170178f, -0.015152611f, + 0.20429322f, 0.25979176f, 0.21269324f, 
0.0025638193f, + + -0.24246037f, 0.21112341f, -0.36959589f, -0.16091451f, + 0.24183474f, 0.18856162f, 0.094487116f, -0.3053959f, + + -0.35736683f, 0.29276621f, -0.4217523f, -0.20031664f, + 0.33148992f, 0.26928401f, 0.19360018f, -0.39494509f, + + -0.28043351f, 0.24279942f, -0.29154932f, -0.13657911f, + 0.31932494f, 0.3500579f, 0.027172565f, -0.19327414f, + }; + } + + { + data.value_data = { + 0.56916672f, -0.2443777f, 0.47111356f, -0.52134115f, + 0.010381341f, 0.0696759f, -0.071910433f, -0.35201436f, + + 0.70809275f, -0.24479815f, 0.41633749f, -0.34744334f, + -0.0044222325f, 0.25929695f, -0.087832771f, -0.281232f, + + 0.90039468f, -0.28931504f, 0.56394172f, -0.43948689f, + -0.05856207f, 0.33713666f, -0.10320446f, -0.38833332f, + + 0.76054728f, -0.29080144f, 0.50414616f, -0.42371163f, + -0.047198489f, 0.31959397f, -0.22683662f, -0.30321664f, + + 0.56916672f, -0.2443777f, 0.47111356f, -0.52134115f, + 0.010381341f, 0.0696759f, -0.071910433f, -0.35201436f, + + 0.70809275f, -0.24479815f, 0.41633749f, -0.34744334f, + -0.0044222325f, 0.25929695f, -0.087832771f, -0.281232f, + + 0.90039468f, -0.28931504f, 0.56394172f, -0.43948689f, + -0.05856207f, 0.33713666f, -0.10320446f, -0.38833332f, + + 0.76054728f, -0.29080144f, 0.50414616f, -0.42371163f, + -0.047198489f, 0.31959397f, -0.22683662f, -0.30321664f, + }; + } + + { + data.bias_data = { + -0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, + 0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f, + -0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, + -0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f, + 0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, + -0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f + }; + } + + { + data.rel_pos_bias_data = { + -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, + -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f, + 8.640254f, -18.862164f, -3.1202927f, -6.3049207f, + 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f, + + -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, + 
-15.262972f, -10.457437f, -8.4519463f, -4.4281874f, + 10.212368f, -0.28622282f, 12.087646f, 6.5218501f, + 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f, + + -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, + -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f, + 8.640254f, -18.862164f, -3.1202927f, -6.3049207f, + 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f, + + -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, + -15.262972f, -10.457437f, -8.4519463f, -4.4281874f, + 10.212368f, -0.28622282f, 12.087646f, 6.5218501f, + 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f, + + -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, + -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f, + 8.640254f, -18.862164f, -3.1202927f, -6.3049207f, + 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f, + + -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, + -15.262972f, -10.457437f, -8.4519463f, -4.4281874f, + 10.212368f, -0.28622282f, 12.087646f, 6.5218501f, + 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f, + + -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, + -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f, + 8.640254f, -18.862164f, -3.1202927f, -6.3049207f, + 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f, + + -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, + -15.262972f, -10.457437f, -8.4519463f, -4.4281874f, + 10.212368f, -0.28622282f, 12.087646f, 6.5218501f, + 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f, + }; + } + + { + data.fp16_output_data = { + 1.0419922f, 0.13000488f, 0.10528564f, -0.86230469f, + -0.45336914f, 0.39013672f, -0.048858643f, 0.10571289f, + + 0.97265625f, 0.17590332f, 0.015625f, -0.79248047f, + -0.40917969f, 0.31933594f, 0.082763672f, 0.12976074f, + + 1.1455078f, 0.13134766f, 0.15014648f, -0.87451172f, + -0.46142578f, 0.40161133f, 0.04309082f, 0.042663574f, + + 1.0009766f, 0.17004395f, 0.033752441f, -0.80078125f, + -0.41625977f, 0.33349609f, 0.080383301f, 0.11846924f, + + 1.0419922f, 0.13000488f, 0.10528564f, -0.86230469f, + -0.45336914f, 0.39013672f, 
-0.048858643f, 0.10571289f, + + 0.97265625f, 0.17590332f, 0.015625f, -0.79248047f, + -0.40917969f, 0.31933594f, 0.082763672f, 0.12976074f, + + 1.1455078f, 0.13134766f, 0.15014648f, -0.87451172f, + -0.46142578f, 0.40161133f, 0.04309082f, 0.042663574f, + + 1.0009766f, 0.17004395f, 0.033752441f, -0.80078125f, + -0.41625977f, 0.33349609f, 0.080383301f, 0.11846924f, + }; + } + + { + data.fp32_output_data = {}; + } + + data.is_static_kv = false; +} + bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type) { return std::find(data.skip_kernel_types.begin(), data.skip_kernel_types.end(), kernel_type) != data.skip_kernel_types.end(); } diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h index 807e1e207d..bfd65c5794 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h @@ -62,6 +62,8 @@ void GetCrossAttentionData_HeadSize16(AttentionTestData& data); void GetCrossAttentionDataWithPast(AttentionTestData& data); void GetSelfAttentionDataWithPast(AttentionTestData& data); +void GetAttentionDataCutlassRelPosBias(AttentionTestData& data); + bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type); } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 415a1c6f8f..d01e07a46a 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -55,7 +55,9 @@ static void RunMultiHeadAttentionTest( std::vector key_dims = {batch_size, is_static_kv ? kv_sequence_length : sequence_length, hidden_size}; std::vector value_dims = {batch_size, is_static_kv ? 
kv_sequence_length : sequence_length, v_hidden_size}; std::vector bias_dims = {hidden_size + hidden_size + v_hidden_size}; - std::vector rel_pos_bias_dims = {1, num_heads, sequence_length, sequence_length + kv_sequence_length}; + // TODO(wy): Introduce past sequence length to avoid using kv_sequence_length. + std::vector rel_pos_bias_dims = + {1, num_heads, sequence_length, past_key_data.size() ? sequence_length + kv_sequence_length : sequence_length}; std::vector past_key_dims = {batch_size, num_heads, kv_sequence_length, hidden_size / num_heads}; std::vector past_value_dims = past_key_dims; std::vector output_dims = {batch_size, sequence_length, v_hidden_size}; @@ -82,9 +84,10 @@ static void RunMultiHeadAttentionTest( std::vector mask_dims_1 = {batch_size}; std::vector mask_dims_2 = {batch_size, kv_sequence_length}; + std::vector mask_dims_3 = {3 * batch_size + 2}; std::vector& key_padding_mask_dims = (mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN) - ? mask_dims_1 - : mask_dims_2; + ? mask_dims_1 + : (mask_type == AttentionMaskType::MASK_2D_KEY_PADDING ? 
mask_dims_2 : mask_dims_3); if (use_float16) { tester.AddInput("query", query_dims, ToFloat16(query)); @@ -487,5 +490,11 @@ TEST(MultiHeadAttentionTest, SelfAttentionWithPast) { RunMultiHeadAttentionTests(data); } +TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) { + AttentionTestData data; + GetAttentionDataCutlassRelPosBias(data); + RunMultiHeadAttentionTests(data); +} + } // namespace test } // namespace onnxruntime diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index f5e35a3899..4a55eaa33e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.32 + version: 1.0.36 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.32 + version: 1.0.36 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here.