mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-18 01:54:05 +00:00
[ROCm] Add AttentionMode to make attention logic streamline (#15978)
Refactor for future kv cache change.
This commit is contained in:
parent
b28e927ca4
commit
2cf0ae7d01
7 changed files with 348 additions and 78 deletions
|
|
@ -52,7 +52,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
AttentionParameters attn;
|
||||
RocmAttentionParameters attn;
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
|
||||
weights->Shape(),
|
||||
bias->Shape(),
|
||||
|
|
@ -63,7 +63,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
device_prop.maxThreadsPerBlock,
|
||||
past_seq_len));
|
||||
ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention
|
||||
ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted
|
||||
ORT_ENFORCE(attn.qkv_format == Q_K_V_BNSH); // non-packed, permuted
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(attn.batch_size);
|
||||
|
|
@ -86,6 +86,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline<HipT>;
|
||||
using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
|
||||
|
||||
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
|
||||
Node().OpType(), &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present}));
|
||||
// TODO: support QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE and QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE
|
||||
ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE ||
|
||||
attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE ||
|
||||
attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE);
|
||||
|
||||
size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn);
|
||||
size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn),
|
||||
AttentionGeneric::GetWorkspaceNumBytes(&attn));
|
||||
|
|
|
|||
|
|
@ -78,6 +78,86 @@ inline int3 Get2DMaskStrides(int total_sequence_length) {
|
|||
return {total_sequence_length, 0, 1};
|
||||
}
|
||||
|
||||
Status ClassifyAttentionMode(
|
||||
const std::string& op,
|
||||
RocmAttentionParameters* attn,
|
||||
const std::vector<const Tensor*>& qkv,
|
||||
const std::vector<const Tensor*>& past,
|
||||
const std::vector<Tensor*>& present) {
|
||||
size_t num_qkv = std::count_if(qkv.cbegin(), qkv.cend(), [](auto it) { return it != nullptr; });
|
||||
size_t num_past = std::count_if(past.cbegin(), past.cend(), [](auto it) { return it != nullptr; });
|
||||
size_t num_present = std::count_if(present.cbegin(), present.cend(), [](auto it) { return it != nullptr; });
|
||||
|
||||
auto hint = MakeString(num_qkv, " qkv inputs, ", num_past, " past inputs and ", num_present, " present inputs");
|
||||
LOGS_DEFAULT(VERBOSE) << hint;
|
||||
|
||||
if (op == "Attention") {
|
||||
ORT_ENFORCE(num_qkv == 0);
|
||||
if (num_past == 0 && num_present == 0) {
|
||||
attn->mode = QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE;
|
||||
return Status::OK();
|
||||
} else if (num_past == 0 && num_present == 1) {
|
||||
if (attn->past_present_share_buffer == false) {
|
||||
attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE;
|
||||
return Status::OK();
|
||||
} else {
|
||||
attn->mode = QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE;
|
||||
return Status::OK();
|
||||
}
|
||||
} else if (num_past == 1 && num_present == 1) {
|
||||
if (attn->past_present_share_buffer == false) {
|
||||
attn->mode = QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE;
|
||||
return Status::OK();
|
||||
} else {
|
||||
attn->mode = QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
} else if (op == "MultiHeadAttention") {
|
||||
if (num_qkv == 3 && num_past == 0 && num_present == 0) {
|
||||
if (attn->qkv_format == Q_K_V_BSNH) {
|
||||
attn->mode = BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE;
|
||||
return Status::OK();
|
||||
} else if (attn->pass_past_in_kv) {
|
||||
attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE;
|
||||
return Status::OK();
|
||||
}
|
||||
} else if (num_qkv == 3 && num_past == 2 && num_present == 2) {
|
||||
if (attn->past_present_share_buffer == false) {
|
||||
if (attn->qkv_format == Q_K_V_BSNH) {
|
||||
attn->mode = BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH;
|
||||
return Status::OK();
|
||||
} else if (attn->pass_past_in_kv) {
|
||||
attn->mode = BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH;
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
if (attn->qkv_format == Q_K_V_BSNH) {
|
||||
attn->mode = BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH;
|
||||
return Status::OK();
|
||||
} else if (attn->pass_past_in_kv) {
|
||||
attn->mode = BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
} else if (num_qkv == 1 && num_past == 0 && num_present == 0) {
|
||||
if (attn->qkv_format == QKV_BSN3H) {
|
||||
attn->mode = BLN3H_NONE_NONE_NONE_NONE_NONE_NONE;
|
||||
return Status::OK();
|
||||
}
|
||||
} else if (num_qkv == 2 && num_past == 0 && num_present == 0) {
|
||||
if (attn->qkv_format == Q_KV_BSNH_BSN2H) {
|
||||
attn->mode = BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Unsupported AttentionMode for ", op, ". Got qkv format ", attn->qkv_format,
|
||||
". Got ", hint);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status DecoderQkvToContext(
|
||||
const hipDeviceProp_t& prop,
|
||||
|
|
@ -117,7 +197,7 @@ Status DecoderQkvToContext(
|
|||
const T* q = qkv_buffer;
|
||||
// transpose q and copy them to qkv_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size,
|
||||
num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer));
|
||||
num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer));
|
||||
|
||||
const T* k = qkv_buffer + k_buffer_offset;
|
||||
const T* v = qkv_buffer + v_buffer_offset;
|
||||
|
|
@ -125,31 +205,31 @@ Status DecoderQkvToContext(
|
|||
if (!static_kv) {
|
||||
// transpose kv and copy them to qkv_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
|
||||
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
|
||||
} else {
|
||||
// transpose kv and copy them to qkv_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
|
||||
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
|
||||
}
|
||||
} else {
|
||||
if (!static_kv) {
|
||||
// transpose kv and copy them to temp_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer));
|
||||
max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer));
|
||||
// concat cache-k with k and copy to qkv_buffer
|
||||
if (nullptr != key_cache) {
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length,
|
||||
batch_size, head_size, num_heads,
|
||||
max_threads_per_block, 1, key_cache,
|
||||
temp_qkv_buffer, qkv_buffer + k_buffer_offset));
|
||||
batch_size, head_size, num_heads,
|
||||
max_threads_per_block, 1, key_cache,
|
||||
temp_qkv_buffer, qkv_buffer + k_buffer_offset));
|
||||
}
|
||||
// concat cache-v with v and copy to qkv_buffer
|
||||
if (nullptr != value_cache) {
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length,
|
||||
batch_size, head_size, num_heads,
|
||||
max_threads_per_block, 1, value_cache,
|
||||
temp_qkv_buffer + k_buffer_offset,
|
||||
qkv_buffer + v_buffer_offset));
|
||||
batch_size, head_size, num_heads,
|
||||
max_threads_per_block, 1, value_cache,
|
||||
temp_qkv_buffer + k_buffer_offset,
|
||||
qkv_buffer + v_buffer_offset));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -214,7 +294,7 @@ Status DecoderQkvToContext(
|
|||
false, 1.0f, false, nullptr, mask_filter_value));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(stream, kv_sequence_length, sequence_length, batch_size,
|
||||
num_heads, nullptr, scratch1, scratch2, false));
|
||||
num_heads, nullptr, scratch1, scratch2, false));
|
||||
}
|
||||
|
||||
// compute P*V (as V*P), and store in scratch3: BxNxSxH
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "core/providers/rocm/shared_inc/rocm_utils.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
|
||||
|
|
@ -181,6 +182,33 @@ class CompatRocblasMathModeSetter {
|
|||
}
|
||||
};
|
||||
|
||||
enum AttentionMode {
|
||||
// Q,K,V,PastK,PastV,PresentK,PresentV
|
||||
QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE,
|
||||
QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE,
|
||||
QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE,
|
||||
QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE,
|
||||
QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE,
|
||||
BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE,
|
||||
BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE,
|
||||
BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH,
|
||||
BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH,
|
||||
BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH,
|
||||
BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH,
|
||||
BLN3H_NONE_NONE_NONE_NONE_NONE_NONE,
|
||||
BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE,
|
||||
};
|
||||
|
||||
struct RocmAttentionParameters : AttentionParameters {
|
||||
AttentionMode mode;
|
||||
};
|
||||
|
||||
Status ClassifyAttentionMode(const std::string& op,
|
||||
RocmAttentionParameters* attn,
|
||||
const std::vector<const Tensor*>& qkv,
|
||||
const std::vector<const Tensor*>& past,
|
||||
const std::vector<Tensor*>& present);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ T: total sequence length
|
|||
N: num of heads
|
||||
H: head dimension
|
||||
|
||||
The following use qkv_format == Q_K_V_BNSH as a example:
|
||||
The following use qkv_format == Q_K_V_BNSH (mode == BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE) as a example:
|
||||
|
||||
BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op
|
||||
|
||||
|
|
@ -57,6 +57,46 @@ non-biased, masked, convert the mask to [B,1,1_or_S,T] and perform broadcast
|
|||
Broadcast add is not actually perform the broadcasting, just broadcast the load operation from memory. The impl details
|
||||
are in composable kernels. The scale and add logic is performed via Acc0ElementOp
|
||||
|
||||
# Classified modes:
|
||||
|
||||
| Q | K | V | past(K)| pastV | present(K)| presentV | Op, desc
|
||||
| ---- | ---- | ---- | ------ | ----- | --------- | -------- | ---------
|
||||
| QFMT | KFMT | VFMT | - | - | - | - | A, basic, qkv is impl dependent by qkv_format
|
||||
| QFMT | KFMT | VFMT | 2BNPH | - | 2BNTH *^ | - | A, past_present_share_buffer = false, qkv is impl dependent by qkv_format
|
||||
| QFMT | KFMT | VFMT | 2BNMH | - | 2BNMH *^ | - | A, past_present_share_buffer = true, qkv is impl dependent by qkv_format
|
||||
| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic
|
||||
| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true
|
||||
| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false
|
||||
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
|
||||
| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
|
||||
| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true
|
||||
| BSNH | BNLH | BNLH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true
|
||||
| BLN3H*^| - | - | - | - | - | - | MHA basic, qkv_packed
|
||||
| BSNH | BLN2H*^| - | - | - | - | - | MHA basic, kv_packed
|
||||
|
||||
Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel
|
||||
|
||||
About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_b for v_buffer
|
||||
|
||||
- Marked with `*` indicate the Tensor is used for k_buffer passing.
|
||||
- Marked with `^` indicate the Tensor is used for v_buffer passing.
|
||||
|
||||
# Supported Op
|
||||
|
||||
- A: Attention
|
||||
- MHA: MultiHeadAttention
|
||||
|
||||
# Dim Value
|
||||
|
||||
- B: batch_size
|
||||
- N: num_heads
|
||||
- H: head_size
|
||||
|
||||
- S: sequence_length
|
||||
- L: kv_sequence_length
|
||||
- P: past_sequence_length
|
||||
- T: total_sequence_length = P + L
|
||||
- M: max_sequence_length
|
||||
*/
|
||||
|
||||
#include "core/framework/tensor_shape.h"
|
||||
|
|
@ -88,66 +128,151 @@ inline int3 Get2DMaskStrides(int total_sequence_length) {
|
|||
return {total_sequence_length, 0, 1};
|
||||
}
|
||||
|
||||
// A stride maps from natural coordinate to physical offset of underlying memory storage buffer offset. We need to
|
||||
// specify both of the natural coordinate order, say (b,n,s,h), (b,s,n,h) or (b,n,h,s), and memory order, say BNSH or
|
||||
// BSNH, to determain the strides. To obtain the offset, we just do the inner product of coordinate with the strides.
|
||||
// This wrapper create the stride vector from the physical dimension (or physical shape).
|
||||
struct Strides {
|
||||
// Create the strides for BNSH physically indexed memory buffer
|
||||
static Strides BNSHMemory(int batch_dim,
|
||||
int num_head_dim,
|
||||
int seqlen_dim,
|
||||
int head_size_dim) {
|
||||
ORT_UNUSED_PARAMETER(batch_dim);
|
||||
return Strides{longlong4{
|
||||
static_cast<int64_t>(num_head_dim) * seqlen_dim * head_size_dim,
|
||||
static_cast<int64_t>(seqlen_dim) * head_size_dim,
|
||||
static_cast<int64_t>(head_size_dim),
|
||||
static_cast<int64_t>(1),
|
||||
}};
|
||||
}
|
||||
|
||||
// Create the strides for BSNH physically indexed memory buffer
|
||||
static Strides BSNHMemory(int batch_dim,
|
||||
int seqlen_dim,
|
||||
int num_head_dim,
|
||||
int head_size_dim) {
|
||||
ORT_UNUSED_PARAMETER(batch_dim);
|
||||
return Strides{longlong4{
|
||||
static_cast<int64_t>(seqlen_dim) * num_head_dim * head_size_dim,
|
||||
static_cast<int64_t>(head_size_dim),
|
||||
static_cast<int64_t>(num_head_dim) * head_size_dim,
|
||||
static_cast<int64_t>(1),
|
||||
}};
|
||||
}
|
||||
|
||||
template <typename T = longlong4>
|
||||
T ForBNSHCoord() {
|
||||
using E = typename T::value_type;
|
||||
return T{static_cast<E>(strides_for_bnsh_coord.x),
|
||||
static_cast<E>(strides_for_bnsh_coord.y),
|
||||
static_cast<E>(strides_for_bnsh_coord.z),
|
||||
static_cast<E>(strides_for_bnsh_coord.w)};
|
||||
}
|
||||
|
||||
template <typename T = longlong4>
|
||||
T ForBSNHCoord() {
|
||||
using E = typename T::value_type;
|
||||
return T{static_cast<E>(strides_for_bnsh_coord.x),
|
||||
static_cast<E>(strides_for_bnsh_coord.z),
|
||||
static_cast<E>(strides_for_bnsh_coord.y),
|
||||
static_cast<E>(strides_for_bnsh_coord.w)};
|
||||
}
|
||||
|
||||
template <typename T = longlong4>
|
||||
T ForBNHSCoord() {
|
||||
using E = typename T::value_type;
|
||||
return T{static_cast<E>(strides_for_bnsh_coord.x),
|
||||
static_cast<E>(strides_for_bnsh_coord.y),
|
||||
static_cast<E>(strides_for_bnsh_coord.w),
|
||||
static_cast<E>(strides_for_bnsh_coord.z)};
|
||||
}
|
||||
|
||||
// store intermediate strides in the canonical (b,n,s,h) coordinate order
|
||||
longlong4 strides_for_bnsh_coord;
|
||||
};
|
||||
|
||||
template <typename HipT, typename T>
|
||||
std::tuple<const HipT*, const HipT*, const HipT*> GetQkvBuffers(
|
||||
const AttentionParameters* attn,
|
||||
const T* query,
|
||||
const T* key,
|
||||
const T* value) {
|
||||
switch (attn->qkv_format) {
|
||||
case Q_K_V_BNSH:
|
||||
case Q_K_V_BSNH:
|
||||
std::tuple<const HipT*, const HipT*, const HipT*> ConvertToOffsetedBufferViews(
|
||||
const RocmAttentionParameters* attn,
|
||||
const T* query = nullptr, // q or packed_qkv
|
||||
const T* key = nullptr, // k or packed kv
|
||||
const T* value = nullptr, //
|
||||
const T* present = nullptr, // present or present_k
|
||||
const T* present_v = nullptr) {
|
||||
ORT_UNUSED_PARAMETER(present_v);
|
||||
switch (attn->mode) {
|
||||
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: {
|
||||
return {reinterpret_cast<const HipT*>(query),
|
||||
reinterpret_cast<const HipT*>(key),
|
||||
reinterpret_cast<const HipT*>(value)};
|
||||
case Q_KV_BSNH_BSN2H: {
|
||||
}
|
||||
case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: {
|
||||
auto offset = static_cast<int64_t>(attn->batch_size) * attn->num_heads * attn->total_sequence_length *
|
||||
attn->head_size;
|
||||
return {reinterpret_cast<const HipT*>(query),
|
||||
reinterpret_cast<const HipT*>(present),
|
||||
reinterpret_cast<const HipT*>(present) + offset};
|
||||
}
|
||||
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: {
|
||||
auto packed_kv = reinterpret_cast<const HipT*>(key);
|
||||
return {reinterpret_cast<const HipT*>(query), packed_kv, packed_kv + attn->head_size};
|
||||
}
|
||||
case QKV_BSN3H: {
|
||||
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE: {
|
||||
auto packed_qkv = reinterpret_cast<const HipT*>(query);
|
||||
return {packed_qkv, packed_qkv + 1 * attn->head_size, packed_qkv + 2 * attn->head_size};
|
||||
}
|
||||
default:
|
||||
return {nullptr, nullptr, nullptr};
|
||||
ORT_ENFORCE("unreachable");
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
inline std::tuple<int4, int4, int4> GetQkvStrides(const AttentionParameters* attn) {
|
||||
inline std::tuple<Strides, Strides, Strides> GetQkvStrides(const RocmAttentionParameters* attn) {
|
||||
// G0 not used, because it is the slowest dimension
|
||||
const int& G1 = attn->num_heads;
|
||||
const int& M = attn->sequence_length;
|
||||
const int& N = attn->total_sequence_length;
|
||||
const int& K = attn->head_size;
|
||||
const int& O = attn->v_head_size;
|
||||
const int& B = attn->batch_size;
|
||||
const int& N = attn->num_heads;
|
||||
const int& S = attn->sequence_length;
|
||||
const int& L = attn->kv_sequence_length;
|
||||
// const int& T = attn->total_sequence_length;
|
||||
const int& H = attn->head_size;
|
||||
const int& Hv = attn->v_head_size;
|
||||
|
||||
int4 q_strides, k_strides, v_strides;
|
||||
switch (attn->qkv_format) {
|
||||
case Q_K_V_BNSH:
|
||||
q_strides = {G1 * M * K, M * K, K, 1};
|
||||
k_strides = {G1 * N * K, N * K, K, 1}; // matrices are transposed
|
||||
v_strides = {G1 * N * O, N * O, 1, O};
|
||||
break;
|
||||
case Q_KV_BSNH_BSN2H:
|
||||
ORT_ENFORCE(K == O);
|
||||
q_strides = {M * G1 * K, K, G1 * K, 1}; // [G0, M, G1, K] layout
|
||||
k_strides = {N * G1 * 2 * K, 2 * K, G1 * 2 * K, 1}; // [G0, N, G1, K] layout
|
||||
v_strides = {N * G1 * 2 * O, 2 * O, 1, G1 * 2 * O}; // [G0, N, G1, O] layout
|
||||
break;
|
||||
case QKV_BSN3H:
|
||||
ORT_ENFORCE(K == O);
|
||||
q_strides = {M * G1 * 3 * K, 3 * K, G1 * 3 * K, 1}; // [G0, M, G1, K] layout
|
||||
k_strides = {N * G1 * 3 * K, 3 * K, G1 * 3 * K, 1}; // [G0, N, G1, K] layout
|
||||
v_strides = {N * G1 * 3 * O, 3 * O, 1, G1 * 3 * O}; // [G0, N, G1, O] layout
|
||||
break;
|
||||
switch (attn->mode) {
|
||||
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
|
||||
if (attn->qkv_format == Q_K_V_BNSH) {
|
||||
return {
|
||||
Strides::BNSHMemory(B, N, S, H),
|
||||
Strides::BNSHMemory(B, N, L, H),
|
||||
Strides::BNSHMemory(B, N, L, Hv),
|
||||
};
|
||||
} else if (attn->qkv_format == Q_K_V_BSNH) {
|
||||
return {
|
||||
Strides::BSNHMemory(B, S, N, H),
|
||||
Strides::BSNHMemory(B, L, N, H),
|
||||
Strides::BSNHMemory(B, L, N, Hv),
|
||||
};
|
||||
}
|
||||
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
|
||||
return {
|
||||
Strides::BSNHMemory(B, S, N, H),
|
||||
Strides::BSNHMemory(B, L, N, 2 * H),
|
||||
Strides::BSNHMemory(B, L, N, 2 * Hv),
|
||||
};
|
||||
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
|
||||
return {
|
||||
Strides::BSNHMemory(B, L, N, 3 * H),
|
||||
Strides::BSNHMemory(B, L, N, 3 * H),
|
||||
Strides::BSNHMemory(B, L, N, 3 * Hv),
|
||||
};
|
||||
default:
|
||||
break;
|
||||
ORT_ENFORCE("unreachable");
|
||||
return {};
|
||||
}
|
||||
return {q_strides, k_strides, v_strides};
|
||||
}
|
||||
|
||||
inline std::tuple<const int*, int3, int3> GetRawMaskBufferAddrSizesAndStrides(
|
||||
const int* buffer, const AttentionParameters* attn) {
|
||||
const int* buffer, const RocmAttentionParameters* attn) {
|
||||
const int* offseted_buffer{buffer}; // how to view the mask buffer
|
||||
int3 sizes{0, 0, 0}; // the logical shape of the view
|
||||
int3 strides{-1, -1, -1}; // the physical memory layout
|
||||
|
|
@ -187,7 +312,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
|
|||
"_H", attention->head_size,
|
||||
bias_buffer != nullptr ? "_B" : "_NB",
|
||||
"_M", mask_index_dims.size(),
|
||||
"_QKV", attention->qkv_format);
|
||||
"_MODE", attention->mode);
|
||||
}
|
||||
|
||||
std::tuple<int, int, int, int, int> GetGemmsMNKOBatch() const {
|
||||
|
|
@ -201,7 +326,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
|
|||
}
|
||||
|
||||
rocblas_handle handle;
|
||||
const AttentionParameters* attention;
|
||||
const RocmAttentionParameters* attention;
|
||||
const hipDeviceProp_t* device_prop;
|
||||
|
||||
float scale;
|
||||
|
|
@ -240,7 +365,7 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
|
|||
return {gemm1_out, softmax_out, gemm2_out};
|
||||
}
|
||||
|
||||
inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) {
|
||||
inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) {
|
||||
return GetAttentionWorkspaceSize(
|
||||
sizeof(T),
|
||||
attn->batch_size,
|
||||
|
|
@ -333,7 +458,8 @@ struct GemmSoftmaxGemmPermuteGenericPipeline {
|
|||
|
||||
static Status Run(const GemmSoftmaxGemmPermuteParams<T>* params, bool use_persistent_softmax) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->attention->qkv_format != Q_K_V_BNSH, "GenericPipeline only supports qkv_format as Q_K_V_BNSH");
|
||||
params->attention->qkv_format != Q_K_V_BNSH,
|
||||
"GenericPipeline only supports qkv_format as Q_K_V_BNSH, got", params->attention->qkv_format);
|
||||
ORT_RETURN_IF_ERROR(Gemm1(params));
|
||||
|
||||
if (UseRawAttentionMask(params)) {
|
||||
|
|
@ -355,19 +481,25 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
|
|||
public:
|
||||
GemmSoftmaxGemmPermuteTunableOp();
|
||||
|
||||
inline static bool IsSupportedQkvFormat(const AttentionParameters* attn) {
|
||||
switch (attn->qkv_format) {
|
||||
case Q_K_V_BNSH:
|
||||
case Q_K_V_BSNH:
|
||||
case Q_KV_BSNH_BSN2H:
|
||||
case QKV_BSN3H:
|
||||
inline static bool IsSupportedMode(const RocmAttentionParameters* attn) {
|
||||
switch (attn->mode) {
|
||||
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
|
||||
case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE:
|
||||
// depends on qkv format
|
||||
if (attn->qkv_format == Q_K_V_BNSH || attn->qkv_format == Q_K_V_BSNH) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
|
||||
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline static bool IsSupportedMaskType(const AttentionParameters* attn) {
|
||||
inline static bool IsSupportedMaskType(const RocmAttentionParameters* attn) {
|
||||
switch (attn->mask_type) {
|
||||
case MASK_NONE:
|
||||
case MASK_2D_DUMMY:
|
||||
|
|
@ -380,7 +512,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
|
|||
}
|
||||
}
|
||||
|
||||
inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) {
|
||||
inline static size_t GetWorkspaceNumBytes(const RocmAttentionParameters* attn) {
|
||||
if (!IsSupportedMaskType(attn)) {
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -477,8 +609,8 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
|
|||
auto op = [impl = std::move(impl), invoker = std::move(invoker)](
|
||||
const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedQkvFormat(params->attention),
|
||||
"qkv format is not supported, got ", params->attention->qkv_format);
|
||||
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
|
||||
"attention mode is not supported, got ", params->attention->mode);
|
||||
if constexpr (USE_BIAS) {
|
||||
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
|
||||
params->bias_buffer == nullptr, "biased version only support input with bias");
|
||||
|
|
@ -512,11 +644,11 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
|
|||
|
||||
auto [qs, ks, vs] = GetQkvStrides(attn);
|
||||
std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
|
||||
std::vector<ck::index_t> q_buffer_strides = {qs.x, qs.y, qs.z, qs.w};
|
||||
std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
|
||||
std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
|
||||
std::vector<ck::index_t> k_buffer_strides = {ks.x, ks.y, ks.z, ks.w};
|
||||
std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
|
||||
std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
|
||||
std::vector<ck::index_t> v_buffer_strides = {vs.x, vs.y, vs.z, vs.w};
|
||||
std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
|
||||
std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
|
||||
std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213
|
||||
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
"bias, key_padding_mask and attention cache is not supported");
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
AttentionParameters attn;
|
||||
RocmAttentionParameters attn;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
multihead_attention_helper::CheckInputs<Tensor>(
|
||||
query, key, value, bias,
|
||||
|
|
@ -95,6 +95,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// TODO: Add support for attention cache
|
||||
ORT_ENFORCE(present_key == nullptr && present_value == nullptr, "attention cache is not supported");
|
||||
|
||||
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
|
||||
Node().OpType(), &attn,
|
||||
/*qkv=*/{query, key, value},
|
||||
/*past=*/{past_key, past_value},
|
||||
/*present=*/{present_key, present_value}));
|
||||
|
||||
|
||||
using HipT = typename ToHipType<T>::MappedType;
|
||||
using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
|
||||
auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn);
|
||||
|
|
@ -107,7 +114,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
params.attention = &attn;
|
||||
params.device_prop = &device_prop;
|
||||
params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_;
|
||||
std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = GetQkvBuffers<HipT>(
|
||||
std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews<HipT>(
|
||||
&attn,
|
||||
query->DataRaw(),
|
||||
key == nullptr ? nullptr : key->DataRaw(),
|
||||
|
|
|
|||
|
|
@ -438,7 +438,6 @@ if __name__ == "__main__":
|
|||
group.add_argument("--scale", type=float, default=None, help="default to 1.0/sqrt(head_size)")
|
||||
group.add_argument(
|
||||
"--qkv_format",
|
||||
type=lambda name: getattr(ke.qkv_format, name),
|
||||
default="Q_K_V_BNSH",
|
||||
choices=[
|
||||
"Q_K_V_BNSH", # non-packed, permuted
|
||||
|
|
@ -462,6 +461,6 @@ if __name__ == "__main__":
|
|||
args.biased,
|
||||
args.mask_dim,
|
||||
args.scale,
|
||||
args.qkv_format,
|
||||
getattr(ke.qkv_format, args.qkv_format),
|
||||
sort=args.sort,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,10 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
|
|||
|
||||
attn_.batch_size = batch;
|
||||
attn_.sequence_length = seqlen;
|
||||
attn_.kv_sequence_length = seqlen; // NOTE: not used
|
||||
// NOTE: This test wrapper does not support past present concat, then past_sequence_length = 0 always holds.
|
||||
// Thus, total_sequence_length = past_sequence_length + kv_sequence_length further implies
|
||||
// total_sequence_length == kv_sequence_length
|
||||
attn_.kv_sequence_length = total_seqlen;
|
||||
attn_.past_sequence_length = 0;
|
||||
attn_.original_past_sequence_length = 0; // NOTE: not used
|
||||
attn_.total_sequence_length = total_seqlen;
|
||||
|
|
@ -66,6 +69,20 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
|
|||
ORT_ENFORCE(false, "mask type not supported");
|
||||
}
|
||||
attn_.qkv_format = qkv_format;
|
||||
switch (qkv_format) {
|
||||
case contrib::Q_K_V_BNSH:
|
||||
case contrib::Q_K_V_BSNH:
|
||||
attn_.mode = contrib::rocm::QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE;
|
||||
break;
|
||||
case contrib::Q_KV_BSNH_BSN2H:
|
||||
attn_.mode = contrib::rocm::BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE;
|
||||
break;
|
||||
case contrib::QKV_BSN3H:
|
||||
attn_.mode = contrib::rocm::BLN3H_NONE_NONE_NONE_NONE_NONE_NONE;
|
||||
break;
|
||||
default:
|
||||
ORT_NOT_IMPLEMENTED("qkv_format ", qkv_format, " is not implemented");
|
||||
}
|
||||
|
||||
device_prop = GetEp()->GetDeviceProp();
|
||||
|
||||
|
|
@ -76,7 +93,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
|
|||
params_.device_prop = &device_prop;
|
||||
params_.scale = scale;
|
||||
|
||||
std::tie(params_.q_buffer, params_.k_buffer, params_.v_buffer) = GetQkvBuffers<T>(
|
||||
std::tie(params_.q_buffer, params_.k_buffer, params_.v_buffer) = ConvertToOffsetedBufferViews<T>(
|
||||
&attn_, Q.ptr(), K.has_value() ? K->ptr() : nullptr, V.has_value() ? V->ptr() : nullptr);
|
||||
|
||||
if (attn_bias.has_value()) {
|
||||
|
|
@ -114,7 +131,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer {
|
|||
using ParamsT = contrib::rocm::GemmSoftmaxGemmPermuteParams<T>;
|
||||
rocblas_handle rocblas_handle_;
|
||||
hipDeviceProp_t device_prop;
|
||||
contrib::AttentionParameters attn_;
|
||||
contrib::rocm::RocmAttentionParameters attn_;
|
||||
ParamsT params_;
|
||||
std::shared_ptr<void> workspace_;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue