From 2cf0ae7d0111e08e1a471bec0259d91bad205e4e Mon Sep 17 00:00:00 2001 From: cloudhan Date: Fri, 26 May 2023 12:06:36 +0800 Subject: [PATCH] [ROCm] Add AttentionMode to make attention logic streamline (#15978) Refactor for future kv cache change. --- .../contrib_ops/rocm/bert/attention.cu | 11 +- .../contrib_ops/rocm/bert/attention_impl.cu | 104 +++++++- .../contrib_ops/rocm/bert/attention_impl.h | 28 ++ ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 246 ++++++++++++++---- .../rocm/bert/multihead_attention.cu | 11 +- .../kernels/gemm_softmax_gemm_permute_test.py | 3 +- .../kernels/rocm/gemm_softmax_gemm_permute.cu | 23 +- 7 files changed, 348 insertions(+), 78 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu index 96c05eaa90..124116497c 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -52,7 +52,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); auto& device_prop = GetDeviceProp(); - AttentionParameters attn; + RocmAttentionParameters attn; ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -63,7 +63,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { device_prop.maxThreadsPerBlock, past_seq_len)); ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention - ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted + ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted TensorShapeVector output_shape(3); output_shape[0] = static_cast(attn.batch_size); @@ -86,6 +86,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; + ORT_RETURN_IF_ERROR(ClassifyAttentionMode( + Node().OpType(), &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present})); + // TODO: support QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE and QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE + ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE || + attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE || + attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE); + size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn), AttentionGeneric::GetWorkspaceNumBytes(&attn)); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 32f225451e..b31311e6ed 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -78,6 +78,86 @@ inline int3 Get2DMaskStrides(int total_sequence_length) { return {total_sequence_length, 0, 1}; } +Status ClassifyAttentionMode( + const std::string& op, + RocmAttentionParameters* attn, + const std::vector& qkv, + const std::vector& past, + const std::vector& present) { + size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; }); + size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; }); + size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; }); + + auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs"); + LOGS_DEFAULT(VERBOSE) << hint; + + if (op == "Attention") { + ORT_ENFORCE(num_qkv == 0); + if (num_past == 0 && num_present == 0) { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; + return Status::OK(); + } else if (num_past == 0 && num_present == 1) { + if (attn->past_present_share_buffer == false) { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE; + return Status::OK(); + } else { + attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE; + return Status::OK(); + } + } else if (num_past == 1 && num_present == 1) { + if (attn->past_present_share_buffer == false) { + attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE; + return Status::OK(); + } else { + attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE; + return Status::OK(); + } + } + } else if (op == "MultiHeadAttention") { + if (num_qkv == 3 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE; + return Status::OK(); + } else if (attn->pass_past_in_kv) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { + if (attn->past_present_share_buffer == false) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH; + return Status::OK(); + } else if (attn->pass_past_in_kv) { + attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH; + return Status::OK(); + } + } else { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH; + return Status::OK(); + } else if (attn->pass_past_in_kv) { + attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH; + return Status::OK(); + } + } + } else if (num_qkv == 1 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == QKV_BSN3H) { + attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } else if (num_qkv == 2 && num_past == 0 && num_present == 0) { + if (attn->qkv_format == Q_KV_BSNH_BSN2H) { + attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; + return Status::OK(); + } + } + } + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Unsupported AttentionMode for ", op, ". Got qkv format ", attn->qkv_format, + ". Got ", hint); +} + template Status DecoderQkvToContext( const hipDeviceProp_t& prop, @@ -117,7 +197,7 @@ Status DecoderQkvToContext( const T* q = qkv_buffer; // transpose q and copy them to qkv_buffer ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, - num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); const T* k = qkv_buffer + k_buffer_offset; const T* v = qkv_buffer + v_buffer_offset; @@ -125,31 +205,31 @@ Status DecoderQkvToContext( if (!static_kv) { // transpose kv and copy them to qkv_buffer ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); } else { // transpose kv and copy them to qkv_buffer ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); } } else { if (!static_kv) { // transpose kv and copy them to temp_buffer ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); // concat cache-k with k and copy to qkv_buffer if (nullptr != key_cache) { ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, key_cache, - temp_qkv_buffer, qkv_buffer + k_buffer_offset)); + batch_size, head_size, num_heads, + max_threads_per_block, 1, key_cache, + temp_qkv_buffer, qkv_buffer + k_buffer_offset)); } // concat cache-v with v and copy to qkv_buffer if (nullptr != value_cache) { ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, - batch_size, head_size, num_heads, - max_threads_per_block, 1, value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); + batch_size, head_size, num_heads, + max_threads_per_block, 1, value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); } } } @@ -214,7 +294,7 @@ Status DecoderQkvToContext( false, 1.0f, false, nullptr, mask_filter_value)); } else { ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, scratch1, scratch2, false)); + num_heads, nullptr, scratch1, scratch2, false)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 896520e4b4..5b6ec6de70 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -5,6 +5,7 @@ #include #include +#include "contrib_ops/cpu/bert/attention_common.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/tunable/rocm_tunable.h" @@ -181,6 +182,33 @@ class CompatRocblasMathModeSetter { } }; +enum AttentionMode { + // Q,K,V,PastK,PastV,PresentK,PresentV + QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE, + QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE, + QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE, + QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE, + QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, + BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, + BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, + BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, + BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, + BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, + BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH, + BLN3H_NONE_NONE_NONE_NONE_NONE_NONE, + BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE, +}; + +struct RocmAttentionParameters : AttentionParameters { + AttentionMode mode; +}; + +Status ClassifyAttentionMode(const std::string& op, + RocmAttentionParameters* attn, + const std::vector& qkv, + const std::vector& past, + const std::vector& present); + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 40b8b54b7f..c8febbc795 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -11,7 +11,7 @@ T: total sequence length N: num of heads H: head dimension -The following use qkv_format == Q_K_V_BNSH as a example: +The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example: BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op @@ -57,6 +57,46 @@ non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast Broadcast add is not actually perform the broadcasting, just broadcast the load operation from memory. The impl details are in composable kernels. The scale and add logic is performed via Acc0ElementOp +# Classified modes: + +| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc +| ---- | ---- | ---- | ------ | ----- | --------- | -------- | --------- +| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format +| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format +| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format +| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic +| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true +| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false +| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false +| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false +| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true +| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true +| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed +| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed + +Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel + +About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_b for v_buffer + +- Marked with `*` indicate the Tensor is used for k_buffer passing. +- Marked with `^` indicate the Tensor is used for v_buffer passing. + +# Supported Op + +- A: Attention +- MHA: MultiHeadAttention + +# Dim Value + +- B: batch_size +- N: num_heads +- H: head_size + +- S: sequence_length +- L: kv_sequence_length +- P: past_sequence_length +- T: total_sequence_length = P + L +- M: max_sequence_length */ #include "core/framework/tensor_shape.h" @@ -88,66 +128,151 @@ inline int3 Get2DMaskStrides(int total_sequence_length) { return {total_sequence_length, 0, 1}; } +// A stride maps from natural coordinate to physical offset of underlying memory storage buffer offset. We need to +// specify both of the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and memory order, say BNSH or +// BSNH, to determain the strides. To obtain the offset, we just do the inner product of coordinate with the strides. +// This wrapper create the stride vector from the physical dimension (or physical shape). +struct Strides { + // Create the strides for BNSH physically indexed memory buffer + static Strides BNSHMemory(int batch_dim, + int num_head_dim, + int seqlen_dim, + int head_size_dim) { + ORT_UNUSED_PARAMETER(batch_dim); + return Strides{longlong4{ + static_cast(num_head_dim) * seqlen_dim * head_size_dim, + static_cast(seqlen_dim) * head_size_dim, + static_cast(head_size_dim), + static_cast(1), + }}; + } + + // Create the strides for BSNH physically indexed memory buffer + static Strides BSNHMemory(int batch_dim, + int seqlen_dim, + int num_head_dim, + int head_size_dim) { + ORT_UNUSED_PARAMETER(batch_dim); + return Strides{longlong4{ + static_cast(seqlen_dim) * num_head_dim * head_size_dim, + static_cast(head_size_dim), + static_cast(num_head_dim) * head_size_dim, + static_cast(1), + }}; + } + + template + T ForBNSHCoord() { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.z), + static_cast(strides_for_bnsh_coord.w)}; + } + + template + T ForBSNHCoord() { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.z), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.w)}; + } + + template + T ForBNHSCoord() { + using E = typename T::value_type; + return T{static_cast(strides_for_bnsh_coord.x), + static_cast(strides_for_bnsh_coord.y), + static_cast(strides_for_bnsh_coord.w), + static_cast(strides_for_bnsh_coord.z)}; + } + + // store intermediate strides in the canonical (b,n,s,h) coordinate order + longlong4 strides_for_bnsh_coord; +}; + template -std::tuple GetQkvBuffers( - const AttentionParameters* attn, - const T* query, - const T* key, - const T* value) { - switch (attn->qkv_format) { - case Q_K_V_BNSH: - case Q_K_V_BSNH: +std::tuple ConvertToOffsetedBufferViews( + const RocmAttentionParameters* attn, + const T* query = nullptr, // q or packed_qkv + const T* key = nullptr, // k or packed kv + const T* value = nullptr, // + const T* present = nullptr, // present or present_k + const T* present_v = nullptr) { + ORT_UNUSED_PARAMETER(present_v); + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: { return {reinterpret_cast(query), reinterpret_cast(key), reinterpret_cast(value)}; - case Q_KV_BSNH_BSN2H: { + } + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { + auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * + attn->head_size; + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present) + offset}; + } + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { auto packed_kv = reinterpret_cast(key); return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; } - case QKV_BSN3H: { + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: { auto packed_qkv = reinterpret_cast(query); return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size}; } default: - return {nullptr, nullptr, nullptr}; + ORT_ENFORCE("unreachable"); + return {}; } } -inline std::tuple GetQkvStrides(const AttentionParameters* attn) { +inline std::tuple GetQkvStrides(const RocmAttentionParameters* attn) { // G0 not used, because it is the slowest dimension - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; + const int& B = attn->batch_size; + const int& N = attn->num_heads; + const int& S = attn->sequence_length; + const int& L = attn->kv_sequence_length; + // const int& T = attn->total_sequence_length; + const int& H = attn->head_size; + const int& Hv = attn->v_head_size; - int4 q_strides, k_strides, v_strides; - switch (attn->qkv_format) { - case Q_K_V_BNSH: - q_strides = {G1 * M * K, M * K, K, 1}; - k_strides = {G1 * N * K, N * K, K, 1}; // matrices are transposed - v_strides = {G1 * N * O, N * O, 1, O}; - break; - case Q_KV_BSNH_BSN2H: - ORT_ENFORCE(K == O); - q_strides = {M * G1 * K, K, G1 * K, 1}; // [G0, M, G1, K] layout - k_strides = {N * G1 * 2 * K, 2 * K, G1 * 2 * K, 1}; // [G0, N, G1, K] layout - v_strides = {N * G1 * 2 * O, 2 * O, 1, G1 * 2 * O}; // [G0, N, G1, O] layout - break; - case QKV_BSN3H: - ORT_ENFORCE(K == O); - q_strides = {M * G1 * 3 * K, 3 * K, G1 * 3 * K, 1}; // [G0, M, G1, K] layout - k_strides = {N * G1 * 3 * K, 3 * K, G1 * 3 * K, 1}; // [G0, N, G1, K] layout - v_strides = {N * G1 * 3 * O, 3 * O, 1, G1 * 3 * O}; // [G0, N, G1, O] layout - break; + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + if (attn->qkv_format == Q_K_V_BNSH) { + return { + Strides::BNSHMemory(B, N, S, H), + Strides::BNSHMemory(B, N, L, H), + Strides::BNSHMemory(B, N, L, Hv), + }; + } else if (attn->qkv_format == Q_K_V_BSNH) { + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, H), + Strides::BSNHMemory(B, L, N, Hv), + }; + } + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, 2 * H), + Strides::BSNHMemory(B, L, N, 2 * Hv), + }; + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, L, N, 3 * H), + Strides::BSNHMemory(B, L, N, 3 * H), + Strides::BSNHMemory(B, L, N, 3 * Hv), + }; default: - break; + ORT_ENFORCE("unreachable"); + return {}; } - return {q_strides, k_strides, v_strides}; } inline std::tuple GetRawMaskBufferAddrSizesAndStrides( - const int* buffer, const AttentionParameters* attn) { + const int* buffer, const RocmAttentionParameters* attn) { const int* offseted_buffer{buffer}; // how to view the mask buffer int3 sizes{0, 0, 0}; // the logical shape of the view int3 strides{-1, -1, -1}; // the physical memory layout @@ -187,7 +312,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { "_H", attention->head_size, bias_buffer != nullptr ? "_B" : "_NB", "_M", mask_index_dims.size(), - "_QKV", attention->qkv_format); + "_MODE", attention->mode); } std::tuple GetGemmsMNKOBatch() const { @@ -201,7 +326,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { } rocblas_handle handle; - const AttentionParameters* attention; + const RocmAttentionParameters* attention; const hipDeviceProp_t* device_prop; float scale; @@ -240,7 +365,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline { return {gemm1_out, softmax_out, gemm2_out}; } - inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { + inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) { return GetAttentionWorkspaceSize( sizeof(T), attn->batch_size, @@ -333,7 +458,8 @@ struct GemmSoftmaxGemmPermuteGenericPipeline { static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->qkv_format != Q_K_V_BNSH, "GenericPipeline only supports qkv_format as Q_K_V_BNSH"); + params->attention->qkv_format != Q_K_V_BNSH, + "GenericPipeline only supports qkv_format as Q_K_V_BNSH, got", params->attention->qkv_format); ORT_RETURN_IF_ERROR(Gemm1(params)); if (UseRawAttentionMask(params)) { @@ -355,19 +481,25 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOpqkv_format) { - case Q_K_V_BNSH: - case Q_K_V_BSNH: - case Q_KV_BSNH_BSN2H: - case QKV_BSN3H: + inline static bool IsSupportedMode(const RocmAttentionParameters* attn) { + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: + // depends on qkv format + if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) { + return true; + } else { + return false; + } + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: return true; default: return false; } } - inline static bool IsSupportedMaskType(const AttentionParameters* attn) { + inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) { switch (attn->mask_type) { case MASK_NONE: case MASK_2D_DUMMY: @@ -380,7 +512,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedQkvFormat(params->attention), - "qkv format is not supported, got ", params->attention->qkv_format); + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), + "attention mode is not supported, got ", params->attention->mode); if constexpr (USE_BIAS) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->bias_buffer == nullptr, "biased version only support input with bias"); @@ -512,11 +644,11 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { auto [qs, ks, vs] = GetQkvStrides(attn); std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = {qs.x, qs.y, qs.z, qs.w}; + std::vector q_buffer_strides = qs.template ForBNSHCoord>(); std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector k_buffer_strides = {ks.x, ks.y, ks.z, ks.w}; + std::vector k_buffer_strides = ks.template ForBNSHCoord>(); std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = {vs.x, vs.y, vs.z, vs.w}; + std::vector v_buffer_strides = vs.template ForBNHSCoord>(); std::vector out_buffer_lengths = {G0, G1, M, O}; std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 27d9c84a6b..aa8a87a1da 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -63,7 +63,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { "bias, key_padding_mask and attention cache is not supported"); auto& device_prop = GetDeviceProp(); - AttentionParameters attn; + RocmAttentionParameters attn; ORT_RETURN_IF_ERROR( multihead_attention_helper::CheckInputs( query, key, value, bias, @@ -95,6 +95,13 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // TODO: Add support for attention cache ORT_ENFORCE(present_key == nullptr && present_value == nullptr, "attention cache is not supported"); + ORT_RETURN_IF_ERROR(ClassifyAttentionMode( + Node().OpType(), &attn, + /*qkv=*/{query, key, value}, + /*past=*/{past_key, past_value}, + /*present=*/{present_key, present_value})); + + using HipT = typename ToHipType::MappedType; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); @@ -107,7 +114,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { params.attention = &attn; params.device_prop = &device_prop; params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; - std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = GetQkvBuffers( + std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews( &attn, query->DataRaw(), key == nullptr ? nullptr : key->DataRaw(), diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py index 1d8bbbdaa1..2eb690b43b 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py @@ -438,7 +438,6 @@ if __name__ == "__main__": group.add_argument("--scale", type=float, default=None, help="default to 1.0/sqrt(head_size)") group.add_argument( "--qkv_format", - type=lambda name: getattr(ke.qkv_format, name), default="Q_K_V_BNSH", choices=[ "Q_K_V_BNSH", # non-packed, permuted @@ -462,6 +461,6 @@ if __name__ == "__main__": args.biased, args.mask_dim, args.scale, - args.qkv_format, + getattr(ke.qkv_format, args.qkv_format), sort=args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu index e87a5c53ca..8277a7737b 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu @@ -39,7 +39,10 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { attn_.batch_size = batch; attn_.sequence_length = seqlen; - attn_.kv_sequence_length = seqlen; // NOTE: not used + // NOTE: This test wrapper does not support past present concat, then past_sequence_length = 0 always holds. + // Thus, total_sequence_length = past_sequence_length + kv_sequence_length further implies + // total_sequence_length == kv_sequence_length + attn_.kv_sequence_length = total_seqlen; attn_.past_sequence_length = 0; attn_.original_past_sequence_length = 0; // NOTE: not used attn_.total_sequence_length = total_seqlen; @@ -66,6 +69,20 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { ORT_ENFORCE(false, "mask type not supported"); } attn_.qkv_format = qkv_format; + switch (qkv_format) { + case contrib::Q_K_V_BNSH: + case contrib::Q_K_V_BSNH: + attn_.mode = contrib::rocm::QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE; + break; + case contrib::Q_KV_BSNH_BSN2H: + attn_.mode = contrib::rocm::BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE; + break; + case contrib::QKV_BSN3H: + attn_.mode = contrib::rocm::BLN3H_NONE_NONE_NONE_NONE_NONE_NONE; + break; + default: + ORT_NOT_IMPLEMENTED("qkv_format ", qkv_format, " is not implemented"); + } device_prop = GetEp()->GetDeviceProp(); @@ -76,7 +93,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { params_.device_prop = &device_prop; params_.scale = scale; - std::tie(params_.q_buffer, params_.k_buffer, params_.v_buffer) = GetQkvBuffers( + std::tie(params_.q_buffer, params_.k_buffer, params_.v_buffer) = ConvertToOffsetedBufferViews( &attn_, Q.ptr(), K.has_value() ? K->ptr() : nullptr, V.has_value() ? V->ptr() : nullptr); if (attn_bias.has_value()) { @@ -114,7 +131,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { using ParamsT = contrib::rocm::GemmSoftmaxGemmPermuteParams; rocblas_handle rocblas_handle_; hipDeviceProp_t device_prop; - contrib::AttentionParameters attn_; + contrib::rocm::RocmAttentionParameters attn_; ParamsT params_; std::shared_ptr workspace_; };