Allow generic pipeline to accept some params for cross attention (#16519)

Allow `GemmSoftmaxGemmPermuteGenericPipeline<T>` to be used for some
cross-attention cases that opt for rocBLAS instead of CK when rocBLAS
is better suited to the small problem size. The improvement is ~20%
end-to-end time reduction on some test cases for Whisper large.

**Note:** This is needed because CK has a performance issue when the
sequence length is only 1; that should be improved in the future.
This commit is contained in:
cloudhan 2023-07-13 09:31:31 +08:00 committed by GitHub
parent 3866614519
commit af89496fc7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 14 deletions

View file

@ -409,6 +409,20 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
void* workspace_buffer{nullptr};
};
// Returns true when both the K and V buffers use the BNMH layout, i.e. the
// modes whose K/V strides are max_sequence_length * head_size rather than the
// compact sequence_length * head_size (see the stride overrides in Gemm1/Gemm2).
inline bool IsKVBNMH(AttentionMode mode) {
  return mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE ||
         mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE ||
         mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH ||
         mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH ||
         mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH ||
         mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH;
}
template <typename T>
struct GemmSoftmaxGemmPermuteGenericPipeline {
static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams<T>* params) {
@ -441,6 +455,12 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams<T>* params) {
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
int k_buffer_stride = n * k;
if (IsKVBNMH(params->attention->mode)) {
k_buffer_stride = params->attention->max_sequence_length * params->attention->head_size;
}
// GEMM1 [m,k] * [n,k]' -> [m,n]
return blas::row_major::StridedBatchedGemm(
params->TuningContext(), params->Stream(), params->handle,
@ -449,7 +469,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
// For raw attention mask, the scalar is moved to softmax computation.
/*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale,
params->q_buffer, k, m * k,
params->k_buffer, k, n * k,
params->k_buffer, k, k_buffer_stride,
/*beta=*/0.0f,
gemm1_out, n, m * n,
batch);
@ -494,6 +514,12 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams<T>* params) {
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
int v_buffer_stride = n * o;
if (IsKVBNMH(params->attention->mode)) {
v_buffer_stride = params->attention->max_sequence_length * params->attention->v_head_size;
}
// GEMM2 [m,n] * [n,o] -> [m,o]
// semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H.
return blas::row_major::StridedBatchedGemm(
@ -502,7 +528,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
m, o, n,
/*alpha=*/1.0f,
softmax_out, n, m * n,
params->v_buffer, o, n * o,
params->v_buffer, o, v_buffer_stride,
/*beta=*/0.0f,
gemm2_out, o, m * o,
batch);
@ -519,10 +545,53 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer);
}
// Decides whether the generic (rocBLAS-based) pipeline can handle the given
// attention configuration.
//
// Returns Status::OK() when supported; otherwise an "unsupported" status
// (via TUNABLE_OP_UNSUPPORTED) so the tuning framework can fall back to
// another implementation. No state is modified.
//
// Fixes vs. previous revision: corrected the "unknonw" typo in the error
// messages and removed the unreachable trailing return (every switch path,
// including default, already returns).
static Status GetSupportedStatus(const GemmSoftmaxGemmPermuteParams<T>* params) {
  const auto& attn = params->attention;
  // TODO: address the BNMH k,v strides
  switch (attn->mode) {
    // Q/K/V all in flattened BNSH-style format: supported only when the
    // declared qkv_format is Q_K_V_BNSH.
    case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
    case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE:
    case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE:
    case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE:
    case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE:
      if (attn->qkv_format == Q_K_V_BNSH) {
        return Status::OK();
      } else {
        return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, got ",
                                      attn->qkv_format);
      }
    case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE:
      return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH but k, v are BLNH");
    case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
    case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
    case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
    case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
    case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
    case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
    case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
    case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
    case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
      // If sequence_length is 1, query of B1NH can be simply viewed as BN1H.
      if (attn->sequence_length == 1) {
        return Status::OK();
      } else {
        return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH, ",
                                      "only if sequence_length is 1, query of BSNH can be viewed as BNSH");
      }
    // Packed QKV / packed KV formats are not supported by this pipeline.
    case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
    case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
      return TUNABLE_OP_UNSUPPORTED("GenericPipeline only supports qkv_format as Q_K_V_BNSH");
    default:
      return TUNABLE_OP_UNSUPPORTED("unknown attention mode");
  }
}
static Status Run(const GemmSoftmaxGemmPermuteParams<T>* params, bool use_persistent_softmax) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->attention->qkv_format != Q_K_V_BNSH,
"GenericPipeline only supports qkv_format as Q_K_V_BNSH, got", params->attention->qkv_format);
auto supported_status = GetSupportedStatus(params);
if (!supported_status.IsOK()) {
return supported_status;
}
ORT_RETURN_IF_ERROR(Gemm1(params));
if (UseRawAttentionMask(params)) {
@ -586,11 +655,16 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
}
inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) {
if (!IsSupportedMaskType(attn)) {
return 0;
size_t num_bytes = GemmSoftmaxGemmPermuteGenericPipeline<T>::GetWorkspaceNumBytes(attn);
#ifdef USE_COMPOSABLE_KERNEL
if (IsSupportedMaskType(attn)) {
auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn);
num_bytes = std::max(num_bytes, sizeof(T) * sizes.x * sizes.y * sizes.z);
}
auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(nullptr, attn);
return sizeof(T) * sizes.x * sizes.y * sizes.z;
#endif
return num_bytes;
}
template <int VecSize, typename Converter>

View file

@ -118,11 +118,12 @@ class Op {
// NOTE: onnxruntime's Status currently does not have a StatusCode::UNSUPPORTED. Currently, we do not want to extend the
// enum. So we reuse StatusCode::INVALID_ARGUMENT for this purpose. It can be interpreted as "The input argument is not
// valid for this specialized kernel implementation.". This semantic is crucial for the tuning mechanism.
#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
do { \
if (condition) { \
return ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__); \
} \
#define TUNABLE_OP_UNSUPPORTED(...) ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__)
#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...) \
do { \
if (condition) { \
return TUNABLE_OP_UNSUPPORTED(__VA_ARGS__); \
} \
} while (false)
template <typename ParamsT, typename TimerT>