From af89496fc7cab02d01d65291e1d0d29e4fc7657d Mon Sep 17 00:00:00 2001 From: cloudhan Date: Thu, 13 Jul 2023 09:31:31 +0800 Subject: [PATCH] Allow generic pipeline to accept some params for cross attention (#16519) Allow `GemmSoftmaxGemmPermuteGenericPipeline` to be used in some cross attention, that opt for rocblas instead of ck if rocblas is better to the small problem. The improvement is ~20% e2e time reduction on some test cases for whisper large. **Note:** This is because ck has some performance issue if the sequence length is merely 1, and should be improved in the future. --- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 92 +++++++++++++++++-- onnxruntime/core/framework/tunable.h | 11 ++- 2 files changed, 89 insertions(+), 14 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 14e1430b41..6ad3c325ff 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -409,6 +409,20 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { void* workspace_buffer{nullptr}; }; +inline bool IsKVBNMH(AttentionMode mode) { + switch (mode) { + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return true; + default: + return false; + } +} + template struct GemmSoftmaxGemmPermuteGenericPipeline { static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { @@ -441,6 +455,12 @@ struct GemmSoftmaxGemmPermuteGenericPipeline { inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + + int k_buffer_stride = n * k; + if (IsKVBNMH(params->attention->mode)) { + k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size; + } + // GEMM1 [m,k] * [n,k]' -> [m,n] return blas::row_major::StridedBatchedGemm( params->TuningContext(), params->Stream(), params->handle, @@ -449,7 +469,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline { // For raw attention mask, the scalar is moved to softmax computation. /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, params->q_buffer, k, m * k, - params->k_buffer, k, n * k, + params->k_buffer, k, k_buffer_stride, /*beta=*/0.0f, gemm1_out, n, m * n, batch); @@ -494,6 +514,12 @@ struct GemmSoftmaxGemmPermuteGenericPipeline { inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams* params) { auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + + int v_buffer_stride = n * o; + if (IsKVBNMH(params->attention->mode)) { + v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size; + } + // GEMM2 [m,n] * [n,o] -> [m,o] // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H. return blas::row_major::StridedBatchedGemm( @@ -502,7 +528,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline { m, o, n, /*alpha=*/1.0f, softmax_out, n, m * n, - params->v_buffer, o, n * o, + params->v_buffer, o, v_buffer_stride, /*beta=*/0.0f, gemm2_out, o, m * o, batch); @@ -519,10 +545,53 @@ struct GemmSoftmaxGemmPermuteGenericPipeline { params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); } + static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams* params) { + const auto& attn = params->attention; + // TODO: address the BNMH k,v strides + switch (attn->mode) { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: + if (attn->qkv_format == Q_K_V_BNSH) { + return Status::OK(); + } else { + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ", + attn->qkv_format); + } + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH"); + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + // If sequence_length is 1, query of B1NH can be simply viewed as BN1H. + if (attn->sequence_length == 1) { + return Status::OK(); + } else { + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ", + "only if sequence_length is 1, query of BSNH can be viewed as BNSH"); + } + case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: + case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: + return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH"); + default: + return TUNABLE_OP_UNSUPPORTED("unknonw"); + } + return TUNABLE_OP_UNSUPPORTED("unknonw case"); + } + static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->attention->qkv_format != Q_K_V_BNSH, - "GenericPipeline only supports qkv_format as Q_K_V_BNSH, got", params->attention->qkv_format); + auto supported_status = GetSupportedStatus(params); + if (!supported_status.IsOK()) { + return supported_status; + } ORT_RETURN_IF_ERROR(Gemm1(params)); if (UseRawAttentionMask(params)) { @@ -586,11 +655,16 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp::GetWorkspaceNumBytes(attn); + +#ifdef USE_COMPOSABLE_KERNEL + if (IsSupportedMaskType(attn)) { + auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn); + num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z); } - auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn); - return sizeof(T) * sizes.x * sizes.y * sizes.z; +#endif + + return num_bytes; } template diff --git a/onnxruntime/core/framework/tunable.h b/onnxruntime/core/framework/tunable.h index 3c3e957f75..65057742cd 100644 --- a/onnxruntime/core/framework/tunable.h +++ b/onnxruntime/core/framework/tunable.h @@ -118,11 +118,12 @@ class Op { // NOTE: onnxruntime's Status currently does not have a StatusCode::UNSUPPORTED. Currently, we do not want to extend the // enum. So we reuse StatusCode::INVALID_ARGUMENT for this purpose. It can be interpreted as "The input argument is not // valid for this specialized kernel implementation.". This semantic is crucial for the tuning mechanism. -#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \ - do { \ - if (condition) { \ - return ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__); \ - } \ +#define TUNABLE_OP_UNSUPPORTED(...) ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__) +#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \ + do { \ + if (condition) { \ + return TUNABLE_OP_UNSUPPORTED(__VA_ARGS__); \ + } \ } while (false) template