[ROCm] Add AttentionMode to make attention logic streamline (#15978)

Refactor for future kv cache change.
This commit is contained in:
cloudhan 2023-05-26 12:06:36 +08:00 committed by GitHub
parent b28e927ca4
commit 2cf0ae7d01
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 348 additions and 78 deletions

View file

@ -52,7 +52,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
auto& device_prop = GetDeviceProp();
AttentionParameters attn;
RocmAttentionParameters attn;
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
@ -63,7 +63,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
device_prop.maxThreadsPerBlock,
past_seq_len));
ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention
ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted
ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(attn.batch_size);
@ -86,6 +86,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline<HipT>;
using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
Node().OpType(), &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present}));
// TODO: support QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE and QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE
ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE ||
attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE ||
attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE);
size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn);
size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn),
AttentionGeneric::GetWorkspaceNumBytes(&attn));

View file

@ -78,6 +78,86 @@ inline int3 Get2DMaskStrides(int total_sequence_length) {
return {total_sequence_length, 0, 1};
}
// Selects the AttentionMode for one operator invocation and stores it in attn->mode.
//
// The decision is driven by how many entries of `qkv`, `past` and `present` are non-null, combined with
// attn->qkv_format, attn->pass_past_in_kv and attn->past_present_share_buffer.
//
// op:       operator type; only "Attention" and "MultiHeadAttention" have recognized combinations.
// attn:     parameters to classify; attn->mode is written only when a combination is recognized.
// qkv, past, present: tensors of the operator; a null entry counts as "not provided".
//
// Returns Status::OK() when a mode was selected, otherwise an INVALID_ARGUMENT status that reports the op,
// the qkv format and the tensor counts. For "Attention", a non-null qkv entry trips ORT_ENFORCE.
Status ClassifyAttentionMode(
    const std::string& op,
    RocmAttentionParameters* attn,
    const std::vector<const Tensor*>& qkv,
    const std::vector<const Tensor*>& past,
    const std::vector<Tensor*>& present) {
  const auto is_provided = [](auto tensor) { return tensor != nullptr; };
  size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), is_provided);
  size_t num_past = std::count_if(past.cbegin(), past.cend(), is_provided);
  size_t num_present = std::count_if(present.cbegin(), present.cend(), is_provided);

  auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs");
  LOGS_DEFAULT(VERBOSE) << hint;

  // Records the chosen mode and reports success.
  const auto select = [attn](AttentionMode mode) {
    attn->mode = mode;
    return Status::OK();
  };

  const bool no_cache = num_past == 0 && num_present == 0;

  if (op == "Attention") {
    ORT_ENFORCE(num_qkv == 0);

    if (no_cache) {
      return select(QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE);
    }
    if (num_past == 0 && num_present == 1) {
      return select(attn->past_present_share_buffer == false ? QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE
                                                              : QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE);
    }
    if (num_past == 1 && num_present == 1) {
      return select(attn->past_present_share_buffer == false ? QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE
                                                              : QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE);
    }
  } else if (op == "MultiHeadAttention") {
    if (num_qkv == 3 && no_cache) {
      if (attn->qkv_format == Q_K_V_BSNH) {
        return select(BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE);
      }
      if (attn->pass_past_in_kv) {
        return select(BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE);
      }
    } else if (num_qkv == 3 && num_past == 2 && num_present == 2) {
      const bool separate_buffers = attn->past_present_share_buffer == false;
      if (attn->qkv_format == Q_K_V_BSNH) {
        return select(separate_buffers ? BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH
                                       : BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH);
      }
      if (attn->pass_past_in_kv) {
        return select(separate_buffers ? BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH
                                       : BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH);
      }
    } else if (num_qkv == 1 && no_cache) {
      if (attn->qkv_format == QKV_BSN3H) {
        return select(BLN3H_NONE_NONE_NONE_NONE_NONE_NONE);
      }
    } else if (num_qkv == 2 && no_cache) {
      if (attn->qkv_format == Q_KV_BSNH_BSN2H) {
        return select(BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE);
      }
    }
  }

  // Nothing matched: report the combination that was seen.
  return ORT_MAKE_STATUS(
      ONNXRUNTIME, INVALID_ARGUMENT,
      "Unsupported AttentionMode for ", op, ". Got qkv format ", attn->qkv_format,
      ". Got ", hint);
}
template <typename T>
Status DecoderQkvToContext(
const hipDeviceProp_t& prop,
@ -117,7 +197,7 @@ Status DecoderQkvToContext(
const T* q = qkv_buffer;
// transpose q and copy them to qkv_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size,
num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer));
num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer));
const T* k = qkv_buffer + k_buffer_offset;
const T* v = qkv_buffer + v_buffer_offset;
@ -125,31 +205,31 @@ Status DecoderQkvToContext(
if (!static_kv) {
// transpose kv and copy them to qkv_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
} else {
// transpose kv and copy them to qkv_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
}
} else {
if (!static_kv) {
// transpose kv and copy them to temp_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer));
max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer));
// concat cache-k with k and copy to qkv_buffer
if (nullptr != key_cache) {
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length,
batch_size, head_size, num_heads,
max_threads_per_block, 1, key_cache,
temp_qkv_buffer, qkv_buffer + k_buffer_offset));
batch_size, head_size, num_heads,
max_threads_per_block, 1, key_cache,
temp_qkv_buffer, qkv_buffer + k_buffer_offset));
}
// concat cache-v with v and copy to qkv_buffer
if (nullptr != value_cache) {
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length,
batch_size, head_size, num_heads,
max_threads_per_block, 1, value_cache,
temp_qkv_buffer + k_buffer_offset,
qkv_buffer + v_buffer_offset));
batch_size, head_size, num_heads,
max_threads_per_block, 1, value_cache,
temp_qkv_buffer + k_buffer_offset,
qkv_buffer + v_buffer_offset));
}
}
}
@ -214,7 +294,7 @@ Status DecoderQkvToContext(
false, 1.0f, false, nullptr, mask_filter_value));
} else {
ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(stream, kv_sequence_length, sequence_length, batch_size,
num_heads, nullptr, scratch1, scratch2, false));
num_heads, nullptr, scratch1, scratch2, false));
}
// compute P*V (as V*P), and store in scratch3: BxNxSxH

View file

@ -5,6 +5,7 @@
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
@ -181,6 +182,33 @@ class CompatRocblasMathModeSetter {
}
};
// Classified layout combination of the tensors taking part in one attention call.
//
// Every enumerator name is seven '_'-separated tags, one per tensor slot, in the order listed on the first
// line inside the enum. NONE means the slot is not provided. QFMT/KFMT/VFMT mean the layout of q/k/v is
// determined by AttentionParameters::qkv_format rather than by the mode itself.
// Layout letters: B batch_size, N num_heads, H head_size, S sequence_length, L kv_sequence_length,
// P past_sequence_length, T total_sequence_length, M max_sequence_length.
// NOTE(review): the leading "2" in tags such as 2BNTH presumably denotes K and V stacked in one tensor --
// confirm against the op schema.
enum AttentionMode {
// Q,K,V,PastK,PastV,PresentK,PresentV
// Modes selected by ClassifyAttentionMode for the "Attention" op.
QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE,
QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE,
QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE,
QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE,
QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE,
// Modes selected by ClassifyAttentionMode for the "MultiHeadAttention" op.
BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE,
BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE,
BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH,
BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH,
BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH,
BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH,
// Packed inputs: a single packed qkv tensor, and q plus a packed kv tensor, respectively.
BLN3H_NONE_NONE_NONE_NONE_NONE_NONE,
BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE,
};
// AttentionParameters extended with the classified AttentionMode used by the ROCm attention code.
struct RocmAttentionParameters : AttentionParameters {
// Has no default initializer; it must be assigned (e.g. by ClassifyAttentionMode) before it is read.
AttentionMode mode;
};
// Determines the AttentionMode for operator type `op` from the number of non-null tensors in `qkv`, `past`
// and `present` together with fields already present in `attn`, and stores the result in attn->mode.
// Returns a non-OK Status when the combination is not recognized.
Status ClassifyAttentionMode(const std::string& op,
RocmAttentionParameters* attn,
const std::vector<const Tensor*>& qkv,
const std::vector<const Tensor*>& past,
const std::vector<Tensor*>& present);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -11,7 +11,7 @@ T: total sequence length
N: num of heads
H: head dimension
The following uses qkv_format == Q_K_V_BNSH as an example:
The following uses qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as an example:
BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op
@ -57,6 +57,46 @@ non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast
Broadcast add does not actually perform the broadcasting; it only broadcasts the load operation from memory. The impl details
are in composable kernels. The scale and add logic is performed via Acc0ElementOp
# Classified modes:
| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ---------
| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format
| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format
| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format
| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic
| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true
| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true
| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true
| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed
| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed
Q, K, V, past(K), pastV, present(K), presentV are the inputs of the contrib OpKernel
About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer
- A `*` mark indicates the Tensor is used for k_buffer passing.
- A `^` mark indicates the Tensor is used for v_buffer passing.
# Supported Op
- A: Attention
- MHA: MultiHeadAttention
# Dim Value
- B: batch_size
- N: num_heads
- H: head_size
- S: sequence_length
- L: kv_sequence_length
- P: past_sequence_length
- T: total_sequence_length = P + L
- M: max_sequence_length
*/
#include "core/framework/tensor_shape.h"
@ -88,66 +128,151 @@ inline int3 Get2DMaskStrides(int total_sequence_length) {
return {total_sequence_length, 0, 1};
}
// A stride maps a natural coordinate to a physical offset in the underlying memory buffer. Both the natural
// coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and the memory order, say BNSH or BSNH, are needed to
// determine the strides. The offset is then the inner product of the coordinate with the strides.
// This wrapper creates the stride vector from the physical dimensions (the physical shape) and hands it out
// reordered for whichever coordinate order the consumer indexes with.
struct Strides {
  // Create the strides for a buffer physically laid out as BNSH.
  // batch_dim is accepted for symmetry only; the slowest dimension does not contribute to any stride.
  static Strides BNSHMemory(int batch_dim,
                            int num_head_dim,
                            int seqlen_dim,
                            int head_size_dim) {
    ORT_UNUSED_PARAMETER(batch_dim);
    // Widen before multiplying so the products are computed in 64 bits.
    return Strides{longlong4{
        static_cast<int64_t>(num_head_dim) * seqlen_dim * head_size_dim,
        static_cast<int64_t>(seqlen_dim) * head_size_dim,
        static_cast<int64_t>(head_size_dim),
        static_cast<int64_t>(1),
    }};
  }

  // Create the strides for a buffer physically laid out as BSNH.
  // The result is still stored in (b,n,s,h) coordinate order, hence the n and s entries look swapped.
  static Strides BSNHMemory(int batch_dim,
                            int seqlen_dim,
                            int num_head_dim,
                            int head_size_dim) {
    ORT_UNUSED_PARAMETER(batch_dim);
    return Strides{longlong4{
        static_cast<int64_t>(seqlen_dim) * num_head_dim * head_size_dim,
        static_cast<int64_t>(head_size_dim),
        static_cast<int64_t>(num_head_dim) * head_size_dim,
        static_cast<int64_t>(1),
    }};
  }

  // The accessors below return the four strides reordered for the requested coordinate order.
  // T must expose a nested value_type and be brace-constructible from four values of that type.
  // They only read the stored strides, so they are const and usable on const Strides objects.

  // Strides ordered for (b,n,s,h) coordinates.
  template <typename T = longlong4>
  T ForBNSHCoord() const {
    using E = typename T::value_type;
    return T{static_cast<E>(strides_for_bnsh_coord.x),
             static_cast<E>(strides_for_bnsh_coord.y),
             static_cast<E>(strides_for_bnsh_coord.z),
             static_cast<E>(strides_for_bnsh_coord.w)};
  }

  // Strides ordered for (b,s,n,h) coordinates.
  template <typename T = longlong4>
  T ForBSNHCoord() const {
    using E = typename T::value_type;
    return T{static_cast<E>(strides_for_bnsh_coord.x),
             static_cast<E>(strides_for_bnsh_coord.z),
             static_cast<E>(strides_for_bnsh_coord.y),
             static_cast<E>(strides_for_bnsh_coord.w)};
  }

  // Strides ordered for (b,n,h,s) coordinates.
  template <typename T = longlong4>
  T ForBNHSCoord() const {
    using E = typename T::value_type;
    return T{static_cast<E>(strides_for_bnsh_coord.x),
             static_cast<E>(strides_for_bnsh_coord.y),
             static_cast<E>(strides_for_bnsh_coord.w),
             static_cast<E>(strides_for_bnsh_coord.z)};
  }

  // store intermediate strides in the canonical (b,n,s,h) coordinate order
  longlong4 strides_for_bnsh_coord;
};
template <typename HipT, typename T>
std::tuple<const HipT*, const HipT*, const HipT*> GetQkvBuffers(
const AttentionParameters* attn,
const T* query,
const T* key,
const T* value) {
switch (attn->qkv_format) {
case Q_K_V_BNSH:
case Q_K_V_BSNH:
std::tuple<const HipT*, const HipT*, const HipT*> ConvertToOffsetedBufferViews(
const RocmAttentionParameters* attn,
const T* query = nullptr, // q or packed_qkv
const T* key = nullptr, // k or packed kv
const T* value = nullptr, //
const T* present = nullptr, // present or present_k
const T* present_v = nullptr) {
ORT_UNUSED_PARAMETER(present_v);
switch (attn->mode) {
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: {
return {reinterpret_cast<const HipT*>(query),
reinterpret_cast<const HipT*>(key),
reinterpret_cast<const HipT*>(value)};
case Q_KV_BSNH_BSN2H: {
}
case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: {
auto offset = static_cast<int64_t>(attn->batch_size) * attn->num_heads * attn->total_sequence_length *
attn->head_size;
return {reinterpret_cast<const HipT*>(query),
reinterpret_cast<const HipT*>(present),
reinterpret_cast<const HipT*>(present) + offset};
}
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: {
auto packed_kv = reinterpret_cast<const HipT*>(key);
return {reinterpret_cast<const HipT*>(query), packed_kv, packed_kv + attn->head_size};
}
case QKV_BSN3H: {
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: {
auto packed_qkv = reinterpret_cast<const HipT*>(query);
return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size};
}
default:
return {nullptr, nullptr, nullptr};
ORT_ENFORCE("unreachable");
return {};
}
}
inline std::tuple<int4, int4, int4> GetQkvStrides(const AttentionParameters* attn) {
inline std::tuple<Strides, Strides, Strides> GetQkvStrides(const RocmAttentionParameters* attn) {
// G0 not used, because it is the slowest dimension
const int& G1 = attn->num_heads;
const int& M = attn->sequence_length;
const int& N = attn->total_sequence_length;
const int& K = attn->head_size;
const int& O = attn->v_head_size;
const int& B = attn->batch_size;
const int& N = attn->num_heads;
const int& S = attn->sequence_length;
const int& L = attn->kv_sequence_length;
// const int& T = attn->total_sequence_length;
const int& H = attn->head_size;
const int& Hv = attn->v_head_size;
int4 q_strides, k_strides, v_strides;
switch (attn->qkv_format) {
case Q_K_V_BNSH:
q_strides = {G1 * M * K, M * K, K, 1};
k_strides = {G1 * N * K, N * K, K, 1}; // matrices are transposed
v_strides = {G1 * N * O, N * O, 1, O};
break;
case Q_KV_BSNH_BSN2H:
ORT_ENFORCE(K == O);
q_strides = {M * G1 * K, K, G1 * K, 1}; // [G0, M, G1, K] layout
k_strides = {N * G1 * 2 * K, 2 * K, G1 * 2 * K, 1}; // [G0, N, G1, K] layout
v_strides = {N * G1 * 2 * O, 2 * O, 1, G1 * 2 * O}; // [G0, N, G1, O] layout
break;
case QKV_BSN3H:
ORT_ENFORCE(K == O);
q_strides = {M * G1 * 3 * K, 3 * K, G1 * 3 * K, 1}; // [G0, M, G1, K] layout
k_strides = {N * G1 * 3 * K, 3 * K, G1 * 3 * K, 1}; // [G0, N, G1, K] layout
v_strides = {N * G1 * 3 * O, 3 * O, 1, G1 * 3 * O}; // [G0, N, G1, O] layout
break;
switch (attn->mode) {
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
if (attn->qkv_format == Q_K_V_BNSH) {
return {
Strides::BNSHMemory(B, N, S, H),
Strides::BNSHMemory(B, N, L, H),
Strides::BNSHMemory(B, N, L, Hv),
};
} else if (attn->qkv_format == Q_K_V_BSNH) {
return {
Strides::BSNHMemory(B, S, N, H),
Strides::BSNHMemory(B, L, N, H),
Strides::BSNHMemory(B, L, N, Hv),
};
}
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
return {
Strides::BSNHMemory(B, S, N, H),
Strides::BSNHMemory(B, L, N, 2 * H),
Strides::BSNHMemory(B, L, N, 2 * Hv),
};
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
return {
Strides::BSNHMemory(B, L, N, 3 * H),
Strides::BSNHMemory(B, L, N, 3 * H),
Strides::BSNHMemory(B, L, N, 3 * Hv),
};
default:
break;
ORT_ENFORCE("unreachable");
return {};
}
return {q_strides, k_strides, v_strides};
}
inline std::tuple<const int*, int3, int3> GetRawMaskBufferAddrSizesAndStrides(
const int* buffer, const AttentionParameters* attn) {
const int* buffer, const RocmAttentionParameters* attn) {
const int* offseted_buffer{buffer}; // how to view the mask buffer
int3 sizes{0, 0, 0}; // the logical shape of the view
int3 strides{-1, -1, -1}; // the physical memory layout
@ -187,7 +312,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
"_H", attention->head_size,
bias_buffer != nullptr ? "_B" : "_NB",
"_M", mask_index_dims.size(),
"_QKV", attention->qkv_format);
"_MODE", attention->mode);
}
std::tuple<int, int, int, int, int> GetGemmsMNKOBatch() const {
@ -201,7 +326,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
}
rocblas_handle handle;
const AttentionParameters* attention;
const RocmAttentionParameters* attention;
const hipDeviceProp_t* device_prop;
float scale;
@ -240,7 +365,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
return {gemm1_out, softmax_out, gemm2_out};
}
inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) {
inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) {
return GetAttentionWorkspaceSize(
sizeof(T),
attn->batch_size,
@ -333,7 +458,8 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
static Status Run(const GemmSoftmaxGemmPermuteParams<T>* params, bool use_persistent_softmax) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->attention->qkv_format != Q_K_V_BNSH, "GenericPipeline only supports qkv_format as Q_K_V_BNSH");
params->attention->qkv_format != Q_K_V_BNSH,
"GenericPipeline only supports qkv_format as Q_K_V_BNSH, got", params->attention->qkv_format);
ORT_RETURN_IF_ERROR(Gemm1(params));
if (UseRawAttentionMask(params)) {
@ -355,19 +481,25 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
public:
GemmSoftmaxGemmPermuteTunableOp();
inline static bool IsSupportedQkvFormat(const AttentionParameters* attn) {
switch (attn->qkv_format) {
case Q_K_V_BNSH:
case Q_K_V_BSNH:
case Q_KV_BSNH_BSN2H:
case QKV_BSN3H:
inline static bool IsSupportedMode(const RocmAttentionParameters* attn) {
switch (attn->mode) {
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE:
// depends on qkv format
if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) {
return true;
} else {
return false;
}
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
return true;
default:
return false;
}
}
inline static bool IsSupportedMaskType(const AttentionParameters* attn) {
inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) {
switch (attn->mask_type) {
case MASK_NONE:
case MASK_2D_DUMMY:
@ -380,7 +512,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
}
}
inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) {
inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) {
if (!IsSupportedMaskType(attn)) {
return 0;
}
@ -477,8 +609,8 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
auto op = [impl = std::move(impl), invoker = std::move(invoker)](
const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedQkvFormat(params->attention),
"qkv format is not supported, got ", params->attention->qkv_format);
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
"attention mode is not supported, got ", params->attention->mode);
if constexpr (USE_BIAS) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer == nullptr, "biased version only support input with bias");
@ -512,11 +644,11 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
auto [qs, ks, vs] = GetQkvStrides(attn);
std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
std::vector<ck::index_t> q_buffer_strides = {qs.x, qs.y, qs.z, qs.w};
std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
std::vector<ck::index_t> k_buffer_strides = {ks.x, ks.y, ks.z, ks.w};
std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
std::vector<ck::index_t> v_buffer_strides = {vs.x, vs.y, vs.z, vs.w};
std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213

View file

@ -63,7 +63,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
"bias, key_padding_mask and attention cache is not supported");
auto& device_prop = GetDeviceProp();
AttentionParameters attn;
RocmAttentionParameters attn;
ORT_RETURN_IF_ERROR(
multihead_attention_helper::CheckInputs<Tensor>(
query, key, value, bias,
@ -95,6 +95,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
// TODO: Add support for attention cache
ORT_ENFORCE(present_key == nullptr && present_value == nullptr, "attention cache is not supported");
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
Node().OpType(), &attn,
/*qkv=*/{query, key, value},
/*past=*/{past_key, past_value},
/*present=*/{present_key, present_value}));
using HipT = typename ToHipType<T>::MappedType;
using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn);
@ -107,7 +114,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
params.attention = &attn;
params.device_prop = &device_prop;
params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_;
std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = GetQkvBuffers<HipT>(
std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews<HipT>(
&attn,
query->DataRaw(),
key == nullptr ? nullptr : key->DataRaw(),

View file

@ -438,7 +438,6 @@ if __name__ == "__main__":
group.add_argument("--scale", type=float, default=None, help="default to 1.0/sqrt(head_size)")
group.add_argument(
"--qkv_format",
type=lambda name: getattr(ke.qkv_format, name),
default="Q_K_V_BNSH",
choices=[
"Q_K_V_BNSH", # non-packed, permuted
@ -462,6 +461,6 @@ if __name__ == "__main__":
args.biased,
args.mask_dim,
args.scale,
args.qkv_format,
getattr(ke.qkv_format, args.qkv_format),
sort=args.sort,
)

View file

@ -39,7 +39,10 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
attn_.batch_size = batch;
attn_.sequence_length = seqlen;
attn_.kv_sequence_length = seqlen; // NOTE: not used
// NOTE: This test wrapper does not support past present concat, then past_sequence_length = 0 always holds.
// Thus, total_sequence_length = past_sequence_length + kv_sequence_length further implies
// total_sequence_length == kv_sequence_length
attn_.kv_sequence_length = total_seqlen;
attn_.past_sequence_length = 0;
attn_.original_past_sequence_length = 0; // NOTE: not used
attn_.total_sequence_length = total_seqlen;
@ -66,6 +69,20 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
ORT_ENFORCE(false, "mask type not supported");
}
attn_.qkv_format = qkv_format;
switch (qkv_format) {
case contrib::Q_K_V_BNSH:
case contrib::Q_K_V_BSNH:
attn_.mode = contrib::rocm::QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE;
break;
case contrib::Q_KV_BSNH_BSN2H:
attn_.mode = contrib::rocm::BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE;
break;
case contrib::QKV_BSN3H:
attn_.mode = contrib::rocm::BLN3H_NONE_NONE_NONE_NONE_NONE_NONE;
break;
default:
ORT_NOT_IMPLEMENTED("qkv_format ", qkv_format, " is not implemented");
}
device_prop = GetEp()->GetDeviceProp();
@ -76,7 +93,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
params_.device_prop = &device_prop;
params_.scale = scale;
std::tie(params_.q_buffer, params_.k_buffer, params_.v_buffer) = GetQkvBuffers<T>(
std::tie(params_.q_buffer, params_.k_buffer, params_.v_buffer) = ConvertToOffsetedBufferViews<T>(
&attn_, Q.ptr(), K.has_value() ? K->ptr() : nullptr, V.has_value() ? V->ptr() : nullptr);
if (attn_bias.has_value()) {
@ -114,7 +131,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
using ParamsT = contrib::rocm::GemmSoftmaxGemmPermuteParams<T>;
rocblas_handle rocblas_handle_;
hipDeviceProp_t device_prop;
contrib::AttentionParameters attn_;
contrib::rocm::RocmAttentionParameters attn_;
ParamsT params_;
std::shared_ptr<void> workspace_;
};