Mirror of https://github.com/saymrwulf/onnxruntime.git
Allow generic pipeline to accept some params for cross attention (#16519)
Allow `GemmSoftmaxGemmPermuteGenericPipeline<T>` to be used for some cross-attention cases, opting for rocblas instead of ck when rocblas handles the small problem better. The improvement is ~20% end-to-end time reduction on some test cases for whisper large. **Note:** This is needed because ck currently has a performance issue when the sequence length is merely 1; this should be improved in the future.
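A note on why sequence_length == 1 makes these cross-attention modes reachable by the generic pipeline: a query stored as BSNH with S == 1 occupies exactly the same memory as the BNSH layout the pipeline expects, so it can be reinterpreted without a transpose. A minimal standalone check of that layout equivalence (shapes below are arbitrary and only for illustration):

#include <cassert>

// Flat offsets of element (b, s, n, h) under the two interpretations.
// With S == 1 (so s == 0) they coincide, which is why a B1NH query can be
// viewed as BN1H in place.
int main() {
  const int B = 2, S = 1, N = 8, H = 64;
  for (int b = 0; b < B; ++b) {
    for (int n = 0; n < N; ++n) {
      for (int h = 0; h < H; ++h) {
        const int s = 0;
        const int bsnh = ((b * S + s) * N + n) * H + h;
        const int bnsh = ((b * N + n) * S + s) * H + h;
        assert(bsnh == bnsh);
      }
    }
  }
  return 0;
}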
This commit is contained in:
parent 3866614519
commit af89496fc7

2 changed files with 89 additions and 14 deletions
@@ -409,6 +409,20 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
   void* workspace_buffer{nullptr};
 };
 
+inline bool IsKVBNMH(AttentionMode mode) {
+  switch (mode) {
+    case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE:
+    case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE:
+    case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
+    case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
+    case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
+    case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
+      return true;
+    default:
+      return false;
+  }
+}
+
 template <typename T>
 struct GemmSoftmaxGemmPermuteGenericPipeline {
   static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams<T>* params) {
@@ -441,6 +455,12 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
   inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams<T>* params) {
     auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
     auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
+
+    int k_buffer_stride = n * k;
+    if (IsKVBNMH(params->attention->mode)) {
+      k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size;
+    }
+
     // GEMM1 [m,k] * [n,k]' -> [m,n]
     return blas::row_major::StridedBatchedGemm(
         params->TuningContext(), params->Stream(), params->handle,

@@ -449,7 +469,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
         // For raw attention mask, the scalar is moved to softmax computation.
         /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale,
         params->q_buffer, k, m * k,
-        params->k_buffer, k, n * k,
+        params->k_buffer, k, k_buffer_stride,
         /*beta=*/0.0f,
         gemm1_out, n, m * n,
         batch);
@@ -494,6 +514,12 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
   inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams<T>* params) {
     auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
     auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
+
+    int v_buffer_stride = n * o;
+    if (IsKVBNMH(params->attention->mode)) {
+      v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size;
+    }
+
     // GEMM2 [m,n] * [n,o] -> [m,o]
     // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H.
     return blas::row_major::StridedBatchedGemm(

@@ -502,7 +528,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
         m, o, n,
         /*alpha=*/1.0f,
         softmax_out, n, m * n,
-        params->v_buffer, o, n * o,
+        params->v_buffer, o, v_buffer_stride,
         /*beta=*/0.0f,
         gemm2_out, o, m * o,
         batch);
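For the BNMH modes handled above, K and V live in present-state buffers padded along the sequence axis to max_sequence_length, so the distance between consecutive (batch, head) matrices is max_sequence_length * head_size elements rather than the packed n * k (and likewise max_sequence_length * v_head_size instead of n * o for V). A small self-contained illustration of the stride the two hunks compute; the concrete numbers are made up and not taken from the PR:

#include <cstdio>

// Illustrative only: packed stride vs. BNMH (padded) stride for the K buffer
// of one (batch, head) slice in a strided batched GEMM.
int main() {
  const int n = 1500;                    // total (key) sequence length
  const int k = 64;                      // head_size
  const int max_sequence_length = 1536;  // padded cache length (hypothetical)
  const bool kv_is_bnmh = true;          // what IsKVBNMH(mode) would report

  int k_buffer_stride = n * k;           // tightly packed K: 96000 elements
  if (kv_is_bnmh) {
    k_buffer_stride = max_sequence_length * k;  // padded cache: 98304 elements
  }
  // With the packed stride, every batch/head after the first would read from a
  // wrong offset inside the padded cache; hence the override in Gemm1/Gemm2.
  std::printf("k_buffer_stride = %d\n", k_buffer_stride);
  return 0;
}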
@@ -519,10 +545,53 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
         params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer);
   }
 
+  static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams<T>* params) {
+    const auto& attn = params->attention;
+    // TODO: address the BNMH k,v strides
+    switch (attn->mode) {
+      case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
+      case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE:
+      case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE:
+      case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE:
+      case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE:
+        if (attn->qkv_format == Q_K_V_BNSH) {
+          return Status::OK();
+        } else {
+          return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ",
+                                        attn->qkv_format);
+        }
+      case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE:
+        return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH");
+      case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
+      case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
+      case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
+      case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
+      case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
+      case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
+      case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
+      case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
+      case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
+        // If sequence_length is 1, query of B1NH can be simply viewed as BN1H.
+        if (attn->sequence_length == 1) {
+          return Status::OK();
+        } else {
+          return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ",
+                                        "only if sequence_length is 1, query of BSNH can be viewed as BNSH");
+        }
+      case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
+      case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
+        return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH");
+      default:
+        return TUNABLE_OP_UNSUPPORTED("unknown");
+    }
+    return TUNABLE_OP_UNSUPPORTED("unknown case");
+  }
+
   static Status Run(const GemmSoftmaxGemmPermuteParams<T>* params, bool use_persistent_softmax) {
-    TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
-        params->attention->qkv_format != Q_K_V_BNSH,
-        "GenericPipeline only supports qkv_format as Q_K_V_BNSH, got", params->attention->qkv_format);
+    auto supported_status = GetSupportedStatus(params);
+    if (!supported_status.IsOK()) {
+      return supported_status;
+    }
     ORT_RETURN_IF_ERROR(Gemm1(params));
 
     if (UseRawAttentionMask(params)) {
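GetSupportedStatus is exposed so a caller can probe whether the rocblas-based pipeline applies before committing to it. A hedged sketch of such a call site, assuming it sits in the same ROCm attention translation unit as the pipeline; RunCkAttention is a hypothetical stand-in for the ck-based path and is not part of this PR:

// Sketch only: prefer the generic (rocblas) pipeline for the small problems it
// now supports, otherwise fall back to a ck-based implementation.
template <typename T>
Status RunCrossAttention(const GemmSoftmaxGemmPermuteParams<T>* params,
                         bool use_persistent_softmax) {
  auto generic_status = GemmSoftmaxGemmPermuteGenericPipeline<T>::GetSupportedStatus(params);
  if (generic_status.IsOK()) {
    // e.g. decoder cross attention with sequence_length == 1, where ck underperforms
    return GemmSoftmaxGemmPermuteGenericPipeline<T>::Run(params, use_persistent_softmax);
  }
  return RunCkAttention(params);  // hypothetical ck fallback, not defined here
}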
@@ -586,11 +655,16 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
   }
 
   inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) {
-    if (!IsSupportedMaskType(attn)) {
-      return 0;
+    size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline<T>::GetWorkspaceNumBytes(attn);
+
+#ifdef USE_COMPOSABLE_KERNEL
+    if (IsSupportedMaskType(attn)) {
+      auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn);
+      num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z);
     }
-    auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn);
-    return sizeof(T) * sizes.x * sizes.y * sizes.z;
+#endif
+
+    return num_bytes;
   }
 
   template <int VecSize, typename Converter>
@@ -118,11 +118,12 @@ class Op {
 // NOTE: onnxruntime's Status currently does not have a StatusCode::UNSUPPORTED. Currently, we do not want to extend the
 // enum. So we reuse StatusCode::INVALID_ARGUMENT for this purpose. It can be interpreted as "The input argument is not
 // valid for this specialized kernel implementation.". This semantic is crucial for the tuning mechanism.
-#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...)  \
-  do {                                                             \
-    if (condition) {                                               \
-      return ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__); \
-    }                                                              \
+#define TUNABLE_OP_UNSUPPORTED(...) ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__)
+#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
+  do {                                                            \
+    if (condition) {                                              \
+      return TUNABLE_OP_UNSUPPORTED(__VA_ARGS__);                 \
+    }                                                             \
   } while (false)
 
 template <typename ParamsT, typename TimerT>
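As the NOTE in the hunk above explains, an UNSUPPORTED result is encoded as StatusCode::INVALID_ARGUMENT so the tuning mechanism can treat a failing candidate as "not applicable" rather than as a hard error. A minimal sketch of a candidate using the new macros; the candidate itself is hypothetical and only illustrates the intended semantics:

// Hypothetical tunable-op candidate: declines shapes it cannot handle so the
// tuner skips it and tries the next candidate instead of failing the op.
template <typename T>
Status SmallSeqLenCandidate(const GemmSoftmaxGemmPermuteParams<T>* params) {
  TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
      params->attention->sequence_length != 1,
      "this candidate only handles sequence_length == 1");
  // ... launch the specialized kernel here ...
  return Status::OK();
}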