[ROCm] Add attention kv cache for decoding (#16076)

This commit is contained in:
cloudhan 2023-06-16 14:17:56 +08:00 committed by GitHub
parent 96471491d7
commit 9110e5b9bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 435 additions and 88 deletions

View file

@ -180,6 +180,12 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
const half* past,
const half* k_v,
half* present);
// Copies a 4-D view from `in` to `out` on `stream`.
//   in_shape     extents of the copied region, in (b,n,s,h) coordinate order.
//   in_strides   per-coordinate strides of the source view, in (b,n,s,h) order.
//   out_strides  per-coordinate strides of the destination view, in (b,n,s,h) order.
//   max_threads_per_block  thread budget used to choose the block shape / kernel variant.
// Returns Status::OK() without launching anything when the s extent is 0; otherwise returns the
// status of the kernel launch.
template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,158 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
using namespace onnxruntime::cuda;
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Copies one element per thread between two 4-D strided views addressed in (b,n,s,h) coordinates.
//
// Index mapping used by this kernel:
//   blockIdx.x  -> s      blockIdx.y  -> b
//   threadIdx.x -> h      threadIdx.y -> n
// so the launch is expected to use grid = (s extent, b extent) and block = (>= H, n extent).
// Threads with h >= H do nothing. Strides are in units of T. No shared memory is used.
template <typename T>
__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
                            T* out, longlong4 out_strides                    // coord (b,n,s,h)
) {
  const int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  if (h < H) {
    // The strides are 64-bit; keep the offsets 64-bit as well. Narrowing them to `int`
    // silently wraps as soon as a view addresses more than INT_MAX elements.
    const int64_t in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
    const int64_t out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
    out[out_offset] = in[in_offset];
  }
}
// Variant of the strided (b,n,s,h) copy for the case where H * num_heads does not fit in a single
// thread block: each thread walks the h coordinate in steps of blockDim.x.
//
// Index mapping used by this kernel:
//   blockIdx.x  -> s      blockIdx.y  -> b
//   threadIdx.x -> first h handled by the thread      threadIdx.y -> n
// so the launch is expected to use grid = (s extent, b extent) and block = (any width >= 1, n extent).
// Strides are in units of T. No shared memory is used.
template <typename T>
__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
                                 T* out, longlong4 out_strides                    // coord (b,n,s,h)
) {
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  const int h_step = blockDim.x;
  for (int h = threadIdx.x; h < H; h += h_step) {
    // The strides are 64-bit; keep the offsets 64-bit as well. Narrowing them to `int`
    // silently wraps as soon as a view addresses more than INT_MAX elements.
    const int64_t in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
    const int64_t out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
    out[out_offset] = in[in_offset];
  }
}
// Maps a byte count to a plain type of exactly that size, so that a group of elements can be
// moved as a single value. Only the sizes specialized below are available; any other NumBytes
// leaves ToByteType incomplete and fails to compile.
template <int NumBytes>
struct ToByteType;

// 2 bytes
template <>
struct ToByteType<2> {
  typedef int16_t T;
};

// 4 bytes
template <>
struct ToByteType<4> {
  typedef int32_t T;
};

// 8 bytes
template <>
struct ToByteType<8> {
  typedef int64_t T;
};

// 16 bytes
template <>
struct ToByteType<16> {
  typedef uint4 T;
};

// 32 bytes
template <>
struct ToByteType<32> {
  typedef ulonglong4 T;
};

// Shorthand for the carrier type of NumBytes bytes.
template <int NumBytes>
using ToBytes = typename ToByteType<NumBytes>::T;
// Returns true when `Pack` consecutive elements of T along the h coordinate may be moved as one
// ToBytes<sizeof(T) * Pack> value. This requires:
//   - head_size to be a multiple of Pack,
//   - both views to be contiguous along h (stride 1), otherwise the packed elements are not adjacent,
//   - every b/n/s stride to be a multiple of Pack, otherwise dividing it by Pack loses information,
//   - both base pointers to be aligned for the wider type, otherwise the access is misaligned.
template <int Pack, typename T>
bool CanPackStridedCopy(const T* in, int head_size, const longlong4& in_strides,
                        const T* out, const longlong4& out_strides) {
  using Bytes = ToBytes<sizeof(T) * Pack>;
  const auto is_aligned = [](const void* ptr) {
    return reinterpret_cast<uintptr_t>(ptr) % alignof(Bytes) == 0;
  };
  const auto is_packable = [](const longlong4& strides) {
    return strides.w == 1 &&
           strides.x % Pack == 0 && strides.y % Pack == 0 && strides.z % Pack == 0;
  };
  return head_size % Pack == 0 &&
         is_packable(in_strides) && is_packable(out_strides) &&
         is_aligned(in) && is_aligned(out);
}

// Enqueues the copy with `Pack` elements of T moved per thread (Pack == 1 is the plain path).
// The caller must have validated the packing with CanPackStridedCopy when Pack > 1.
// Launch errors are not checked here; the caller queries cudaGetLastError().
template <int Pack, typename T>
void LaunchStridedCopyPacked(cudaStream_t stream,
                             const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
                             T* out, longlong4 out_strides,                     // coord (b,n,s,h)
                             int max_threads_per_block) {
  using Bytes = ToBytes<sizeof(T) * Pack>;
  const int num_heads = in_shape.y;
  const int H = in_shape.w / Pack;  // h extent in units of Bytes

  // Re-express the b/n/s strides in units of Bytes. The h stride is 1 in either unit.
  in_strides.x /= Pack;
  in_strides.y /= Pack;
  in_strides.z /= Pack;
  out_strides.x /= Pack;
  out_strides.y /= Pack;
  out_strides.z /= Pack;

  const dim3 grid(in_shape.z, in_shape.x);  // (s, b)
  const Bytes* packed_in = reinterpret_cast<const Bytes*>(in);
  Bytes* packed_out = reinterpret_cast<Bytes*>(out);
  if (H * num_heads <= max_threads_per_block) {
    // One thread per (n, h) pair.
    const dim3 block(H, num_heads, 1);
    StridedCopy<Bytes><<<grid, block, 0, stream>>>(packed_in, H, in_strides, packed_out, out_strides);
  } else {
    // Not enough threads for every (n, h) pair: each thread iterates over h.
    const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
    StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(packed_in, H, in_strides, packed_out, out_strides);
  }
}

// Copies a 4-D view of extents `in_shape` from `in` to `out` on `stream`.
//
//   in_shape               extents in (b,n,s,h) order: x = b, y = n, z = s, w = h.
//   in_strides/out_strides strides of the source/destination view, in elements of T, (b,n,s,h) order.
//   max_threads_per_block  thread budget used to choose the block shape / kernel variant.
//
// Elements are moved 4 or 2 at a time when the layout and pointer alignment allow it, and one at a
// time otherwise. A view with any zero extent is a no-op. Returns the status of the kernel launch;
// the copy itself is asynchronous with respect to the host.
template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
                         const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
                         T* out, longlong4 out_strides,                     // coord (b,n,s,h)
                         int max_threads_per_block) {
  // Nothing to copy. A zero extent would otherwise yield a zero-sized grid or block, which is an
  // invalid launch configuration.
  if (in_shape.x == 0 || in_shape.y == 0 || in_shape.z == 0 || in_shape.w == 0) {
    return Status::OK();
  }

  const int head_size = in_shape.w;
  if (CanPackStridedCopy<4>(in, head_size, in_strides, out, out_strides)) {  // pack 4 elements together
    LaunchStridedCopyPacked<4>(stream, in, in_shape, in_strides, out, out_strides, max_threads_per_block);
  } else if (CanPackStridedCopy<2>(in, head_size, in_strides, out, out_strides)) {  // pack 2 elements together
    LaunchStridedCopyPacked<2>(stream, in, in_shape, in_strides, out, out_strides, max_threads_per_block);
  } else {
    LaunchStridedCopyPacked<1>(stream, in, in_shape, in_strides, out, out_strides, max_threads_per_block);
  }

  return CUDA_CALL(cudaGetLastError());
}
// Explicit instantiations of LaunchStridedCopy for the element types provided by this file.
#define INSTANTIATE_LAUNCH_STRIDED_COPY(T)              \
  template Status LaunchStridedCopy<T>(                 \
      cudaStream_t stream,                              \
      const T* in, int4 in_shape, longlong4 in_strides, \
      T* out, longlong4 out_strides,                    \
      int max_threads_per_block)

INSTANTIATE_LAUNCH_STRIDED_COPY(float);
INSTANTIATE_LAUNCH_STRIDED_COPY(half);

#undef INSTANTIATE_LAUNCH_STRIDED_COPY
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -88,10 +88,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
Node().OpType(), &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present}));
// TODO: support QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE and QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE
ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE ||
attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE ||
attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE);
attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE ||
attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE ||
attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE);
size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn);
size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn),
@ -123,27 +124,46 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params));
auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params);
// NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH
if (nullptr != present) {
// Concat past (2xBxNxS'xH) to present (2xBxNxTxH):
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxTxH)
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxTxH)
const int batches = attn.batch_size * attn.num_heads;
const int present_size_per_batch = attn.total_sequence_length * attn.head_size;
ORT_RETURN_IF_ERROR(
LaunchConcatPastToPresent(Stream(context),
attn.total_sequence_length,
attn.sequence_length,
attn.batch_size,
attn.head_size,
attn.num_heads,
device_prop.maxThreadsPerBlock,
nullptr == past ? nullptr : reinterpret_cast<const HipT*>(past->DataRaw()),
k_buffer,
reinterpret_cast<HipT*>(present->MutableDataRaw())));
Strides dst_strides; // the output buffer is present Tensor, the buffer is the same
// update pointers to present_k and present_v.
int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size};
HipT* add_dest = nullptr; // destination of concatenated data to present
const HipT* const add_src = k_buffer; // source of concatenated data to present
const auto add_src_strides = Strides::BNSHMemory(
2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size);
if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) {
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
} else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) {
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
// We only need to copy past to present in this case. All other cases build the present incrementally.
const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size};
HipT* const past_dest = reinterpret_cast<HipT*>(present->MutableDataRaw());
const HipT* const past_src = reinterpret_cast<const HipT*>(past->DataRaw());
const Strides past_src_strides = Strides::BNSHMemory(
2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size);
ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(),
past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
} else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) {
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
} else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) {
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
}
ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(),
add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
// update pointers to present_k and present_v. TODO: switch to ConvertToOffsetedBufferViews
k_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw());
v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + batches * present_size_per_batch;
v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0);
}
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
@ -160,6 +180,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
params.device_prop = &device_prop;
// FIXME: the params.scale seems to be different from AttentionParameters::scale;
params.scale = 1.0f / sqrt(static_cast<float>(attn.head_size));
// TODO: switch to ConvertToOffsetedBufferViews
params.q_buffer = q_buffer;
params.k_buffer = k_buffer;
params.v_buffer = v_buffer;

View file

@ -122,6 +122,24 @@ Status ClassifyAttentionMode(
attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE;
return Status::OK();
}
} else if (num_qkv == 3 && num_past == 0 && num_present == 2) {
if (attn->past_present_share_buffer == false) {
if (attn->qkv_format == Q_K_V_BSNH) {
attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH;
return Status::OK();
} else if (attn->pass_past_in_kv) {
attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH;
return Status::OK();
}
} else {
if (attn->qkv_format == Q_K_V_BSNH) {
attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH;
return Status::OK();
} else if (attn->pass_past_in_kv) {
attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH;
return Status::OK();
}
}
} else if (num_qkv == 3 && num_past == 2 && num_present == 2) {
if (attn->past_present_share_buffer == false) {
if (attn->qkv_format == Q_K_V_BSNH) {

View file

@ -98,28 +98,6 @@ Status LaunchConcatTensorToTensor(hipStream_t stream,
const half* tensor_add,
half* tensor_out);
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present);
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present);
inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
@ -191,6 +169,10 @@ enum AttentionMode {
QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE,
BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE,
BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE,
BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH,
BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH,
BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH,
BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH,
BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH,
BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH,
BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH,
@ -209,6 +191,11 @@ Status ClassifyAttentionMode(const std::string& op,
const std::vector<const Tensor*>& past,
const std::vector<Tensor*>& present);
// Strided copy of a 4-D view from `in` to `out` on `stream`. `in_shape` and both stride arguments
// are given in (b,n,s,h) coordinate order.
// NOTE(review): the HIP definition is not visible here; presumably it mirrors the CUDA
// LaunchStridedCopy implementation (strides in elements of T) - confirm.
template <typename T>
Status LaunchStridedCopy(hipStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -67,6 +67,10 @@ are in composable kernels. The scale and add logic is performed via Acc0ElementO
| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic
| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true
| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false
| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false
| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false
| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true
| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true
@ -76,7 +80,7 @@ are in composable kernels. The scale and add logic is performed via Acc0ElementO
Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel
About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_b for v_buffer
About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer
- Tensors marked with `*` are the ones passed as k_buffer.
- Tensors marked with `^` are the ones passed as v_buffer.
@ -162,7 +166,7 @@ struct Strides {
}
template <typename T = longlong4>
T ForBNSHCoord() {
T ForBNSHCoord() const {
using E = typename T::value_type;
return T{static_cast<E>(strides_for_bnsh_coord.x),
static_cast<E>(strides_for_bnsh_coord.y),
@ -171,7 +175,7 @@ struct Strides {
}
template <typename T = longlong4>
T ForBSNHCoord() {
T ForBSNHCoord() const {
using E = typename T::value_type;
return T{static_cast<E>(strides_for_bnsh_coord.x),
static_cast<E>(strides_for_bnsh_coord.z),
@ -180,7 +184,7 @@ struct Strides {
}
template <typename T = longlong4>
T ForBNHSCoord() {
T ForBNHSCoord() const {
using E = typename T::value_type;
return T{static_cast<E>(strides_for_bnsh_coord.x),
static_cast<E>(strides_for_bnsh_coord.y),
@ -188,6 +192,11 @@ struct Strides {
static_cast<E>(strides_for_bnsh_coord.z)};
}
// Linear offset of the coordinate (b, n, s, h): the sum of each coordinate multiplied by the
// matching stride stored in strides_for_bnsh_coord.
int64_t OffsetAt(int b, int n, int s, int h) const {
  const auto& strides = strides_for_bnsh_coord;
  const auto batch_part = strides.x * b;
  const auto head_part = strides.y * n;
  const auto seq_part = strides.z * s;
  const auto elem_part = strides.w * h;
  return batch_part + head_part + seq_part + elem_part;
}
// store intermediate strides in the canonical (b,n,s,h) coordinate order
longlong4 strides_for_bnsh_coord;
};
@ -200,13 +209,15 @@ std::tuple<const HipT*, const HipT*, const HipT*> ConvertToOffsetedBufferViews(
const T* value = nullptr, //
const T* present = nullptr, // present or present_k
const T* present_v = nullptr) {
ORT_UNUSED_PARAMETER(present_v);
switch (attn->mode) {
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: {
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: {
return {reinterpret_cast<const HipT*>(query),
reinterpret_cast<const HipT*>(key),
reinterpret_cast<const HipT*>(value)};
}
case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE:
case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: {
auto offset = static_cast<int64_t>(attn->batch_size) * attn->num_heads * attn->total_sequence_length *
attn->head_size;
@ -214,6 +225,25 @@ std::tuple<const HipT*, const HipT*, const HipT*> ConvertToOffsetedBufferViews(
reinterpret_cast<const HipT*>(present),
reinterpret_cast<const HipT*>(present) + offset};
}
case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE:
case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: {
auto offset = static_cast<int64_t>(attn->batch_size) * attn->num_heads * attn->max_sequence_length *
attn->head_size;
return {reinterpret_cast<const HipT*>(query),
reinterpret_cast<const HipT*>(present),
reinterpret_cast<const HipT*>(present) + offset};
}
case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
return {reinterpret_cast<const HipT*>(query),
reinterpret_cast<const HipT*>(present),
reinterpret_cast<const HipT*>(present_v)};
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: {
auto packed_kv = reinterpret_cast<const HipT*>(key);
return {reinterpret_cast<const HipT*>(query), packed_kv, packed_kv + attn->head_size};
@ -234,7 +264,8 @@ inline std::tuple<Strides, Strides, Strides> GetQkvStrides(const RocmAttentionPa
const int& N = attn->num_heads;
const int& S = attn->sequence_length;
const int& L = attn->kv_sequence_length;
// const int& T = attn->total_sequence_length;
const int& T = attn->total_sequence_length;
const int& M = attn->max_sequence_length;
const int& H = attn->head_size;
const int& Hv = attn->v_head_size;
@ -253,6 +284,36 @@ inline std::tuple<Strides, Strides, Strides> GetQkvStrides(const RocmAttentionPa
Strides::BSNHMemory(B, L, N, Hv),
};
}
case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
return {
Strides::BSNHMemory(B, S, N, H),
Strides::BNSHMemory(B, N, T, H),
Strides::BNSHMemory(B, N, T, Hv),
};
case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
return {
Strides::BSNHMemory(B, S, N, H),
Strides::BNSHMemory(B, N, M, H),
Strides::BNSHMemory(B, N, M, Hv),
};
case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE:
return {
Strides::BSNHMemory(B, S, N, H),
Strides::BSNHMemory(B, L, N, H),
Strides::BSNHMemory(B, L, N, Hv),
};
case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
return {
Strides::BSNHMemory(B, S, N, H),
Strides::BNSHMemory(B, N, L, H),
Strides::BNSHMemory(B, N, L, Hv),
};
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
return {
Strides::BSNHMemory(B, S, N, H),
@ -310,8 +371,10 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
"_T", attention->total_sequence_length,
"_N", attention->num_heads,
"_H", attention->head_size,
"_Hv", attention->v_head_size,
bias_buffer != nullptr ? "_B" : "_NB",
"_M", mask_index_dims.size(),
"_QKV", attention->qkv_format,
"_MODE", attention->mode);
}
@ -320,7 +383,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
auto m = attention->sequence_length;
auto n = attention->total_sequence_length;
auto k = attention->head_size;
auto o = attention->head_size;
auto o = attention->v_head_size;
auto batch = attention->batch_size * attention->num_heads;
return {m, n, k, o, batch};
}
@ -491,6 +554,16 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
} else {
return false;
}
case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE:
case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
return true;
@ -554,7 +627,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
*reinterpret_cast<StoreT*>(out + out_offset) = store;
} else {
#pragma unroll
for (int i = tidx; i < mask_lengths.z; i++) {
for (int i = 0; i < mask_lengths.z - tidx; i++) {
out[out_offset + i] = cvt(mask_buffer[in_offset + i]);
}
}
@ -639,7 +712,6 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
{
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
ORT_ENFORCE(K == O, "inner product dimension mismatch");
}
auto [qs, ks, vs] = GetQkvStrides(attn);

View file

@ -58,9 +58,11 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past_key = context->Input<Tensor>(6);
const Tensor* past_value = context->Input<Tensor>(7);
// TODO: Add support for bias, key_padding_mask and attention cache.
ORT_ENFORCE(bias == nullptr && key_padding_mask == nullptr && past_key == nullptr && past_value == nullptr,
"bias, key_padding_mask and attention cache is not supported");
if (nullptr != bias) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"qkv_bias is not supported on ROCm EP. "
"User should fuse the qkv bias to qkv projection instead.");
}
auto& device_prop = GetDeviceProp();
RocmAttentionParameters attn;
@ -72,14 +74,10 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
&attn,
num_heads_, mask_filter_value_, scale_,
false, false, device_prop.maxThreadsPerBlock));
// TODO: support more qkv formats
ORT_ENFORCE(attn.qkv_format == Q_KV_BSNH_BSN2H || attn.qkv_format == QKV_BSN3H, "Got ", attn.qkv_format);
int sequence_length = attn.sequence_length;
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(attn.batch_size);
output_shape[1] = static_cast<int64_t>(sequence_length);
output_shape[1] = static_cast<int64_t>(attn.sequence_length);
output_shape[2] = static_cast<int64_t>(attn.v_hidden_size);
Tensor* output = context->Output(0, output_shape);
@ -92,8 +90,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);
// TODO: Add support for attention cache
ORT_ENFORCE(present_key == nullptr && present_value == nullptr, "attention cache is not supported");
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
Node().OpType(), &attn,
@ -101,26 +97,114 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
/*past=*/{past_key, past_value},
/*present=*/{present_key, present_value}));
using HipT = typename ToHipType<T>::MappedType;
using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn);
auto workspace = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
hipStream_t stream = Stream(context);
if (nullptr != present_key) { // process past present concat
Strides dst_strides;
int4 past_shape;
Strides past_src_strides;
const HipT* past_key_src;
const HipT* past_value_src;
HipT* past_key_dst{};
HipT* past_value_dst{};
int4 add_shape;
Strides add_src_strides;
const HipT* add_key_src =reinterpret_cast<const HipT*>(key->DataRaw());
const HipT* add_value_src = reinterpret_cast<const HipT*>(value->DataRaw());
HipT* add_key_dst;
HipT* add_value_dst;
if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH ||
attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) {
dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size};
past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size);
past_key_src = reinterpret_cast<const HipT*>(past_key->DataRaw());
past_value_src = reinterpret_cast<const HipT*>(past_value->DataRaw());
past_key_dst = reinterpret_cast<HipT*>(present_key->MutableDataRaw());
past_value_dst = reinterpret_cast<HipT*>(present_value->MutableDataRaw());
if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) {
add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size);
} else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) {
add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size);
}
} else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH ||
attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) {
dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) {
add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size);
} else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) {
add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size);
}
} else if (
attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH ||
attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH ||
attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH ||
attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) {
dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) {
add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size);
} else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) {
add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"past present concatenation is not implemented for attention mode ", attn.mode);
}
add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h)
add_key_dst = reinterpret_cast<HipT*>(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
add_value_dst = reinterpret_cast<HipT*>(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
if (past_key_dst) {
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(),
past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
}
if (past_value_dst) {
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(),
past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
}
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(),
add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(),
add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
}
GemmSoftmaxGemmPermuteParams<HipT> params;
params.tuning_ctx = GetTuningContext();
params.stream = Stream(context);
params.stream = stream;
params.handle = GetRocblasHandle(context);
params.attention = &attn;
params.device_prop = &device_prop;
params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_;
std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews<HipT>(
&attn,
query->DataRaw(),
key == nullptr ? nullptr : key->DataRaw(),
value == nullptr ? nullptr : value->DataRaw());
nullptr == query ? nullptr : reinterpret_cast<const HipT*>(query->DataRaw()),
nullptr == key ? nullptr : reinterpret_cast<const HipT*>(key->DataRaw()),
nullptr == value ? nullptr : reinterpret_cast<const HipT*>(value->DataRaw()),
nullptr == present_key ? nullptr : reinterpret_cast<const HipT*>(present_key->DataRaw()),
nullptr == present_value ? nullptr : reinterpret_cast<const HipT*>(present_value->DataRaw()));
params.out_buffer = reinterpret_cast<HipT*>(output->MutableDataRaw());
if (key_padding_mask != nullptr) {
params.mask_index_buffer = key_padding_mask->Data<int>();
params.mask_index_dims = key_padding_mask->Shape().AsShapeVector();
}
if (relative_position_bias != nullptr) {
params.bias_buffer = reinterpret_cast<const HipT*>(relative_position_bias->DataRaw());
}

View file

@ -239,6 +239,12 @@ static void RunAttentionTest(
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (enable_rocm) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true));
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (enable_cpu) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());

View file

@ -9,7 +9,7 @@
#include "test/util/include/scoped_env_vars.h"
#include "test/contrib_ops/attention_op_test_helper.h"
#if defined(USE_ROCM) && defined(USE_COMPOSABLE_KERNEL)
#if defined(USE_ROCM) && defined(USE_COMPOSABLE_KERNEL) && !defined(USE_MIGRAPHX)
#define DISABLE_ROCM false
#else
#define DISABLE_ROCM true
@ -21,12 +21,6 @@
#define ROCM_GTEST_SKIP(message)
#endif
#if defined(USE_MIGRAPHX)
#define MIGX_GTEST_SKIP(message) GTEST_SKIP_(message)
#else
#define MIGX_GTEST_SKIP(message)
#endif
namespace onnxruntime {
namespace test {
@ -66,6 +60,16 @@ static void RunMultiHeadAttentionTest(
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu;
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml;
if (enable_rocm && !use_float16) {
LOGS_DEFAULT(WARNING) << "ROCm MHA only have kernel for half datatype implemented, skip float datatype tests";
enable_rocm = false;
}
if (enable_rocm && !bias_data.empty()) {
LOGS_DEFAULT(WARNING) << "ROCm MHA does not support qkv_bias, skip qkv_bias tests";
enable_rocm = false;
}
if (enable_cpu || enable_cuda || enable_rocm || enable_dml) {
OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
@ -457,7 +461,6 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu
// Test fused cross attention kernel
// It requires head_size > 32 and head_size <= 64 for T4 GPU; hidden_size == v_hidden_size.
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
ROCM_GTEST_SKIP("ROCm MHA does not support bias");
AttentionTestData data;
GetCrossAttentionData_HeadSize40(data);
RunMultiHeadAttentionTests(data);
@ -467,7 +470,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) {
ROCM_GTEST_SKIP("ROCm MHA does not support mask");
ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN");
AttentionTestData data;
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true);
RunMultiHeadAttentionTests(data, true);
@ -477,7 +480,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) {
ROCM_GTEST_SKIP("ROCm MHA does not support mask");
ROCM_GTEST_SKIP("ROCm MHA expect failure due to ck bug");
AttentionTestData data;
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false);
RunMultiHeadAttentionTests(data, true);
@ -487,7 +490,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
ROCM_GTEST_SKIP("ROCm MHA does not support mask");
AttentionTestData data;
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
RunMultiHeadAttentionTests(data, true);
@ -497,14 +499,12 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) {
MIGX_GTEST_SKIP("MIGX MHA does not support Packed KV");
AttentionTestData data;
GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data);
RunMultiHeadAttentionTests(data);
}
TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) {
MIGX_GTEST_SKIP("MIGX MHA does not support Packed QKV");
AttentionTestData data;
GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data);
RunMultiHeadAttentionTests(data);
@ -513,7 +513,6 @@ TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_Packe
// This tests qk_head_size != v_head_size
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) {
ROCM_GTEST_SKIP("ROCm MHA does not support bias");
AttentionTestData data;
GetCrossAttentionData_HeadSize16_8(data);
RunMultiHeadAttentionTests(data);
@ -523,7 +522,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) {
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
ROCM_GTEST_SKIP("ROCm MHA does not support bias");
AttentionTestData data;
GetCrossAttentionData_HeadSize16(data);
RunMultiHeadAttentionTests(data);
@ -533,21 +531,21 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
}
TEST(MultiHeadAttentionTest, CrossAttentionWithPast) {
ROCM_GTEST_SKIP("ROCm MHA does not support attention cache");
ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8");
AttentionTestData data;
GetCrossAttentionDataWithPast(data);
RunMultiHeadAttentionTests(data);
}
TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithRelPosBias_ForT5) {
ROCM_GTEST_SKIP("ROCm MHA does not support attention cache");
ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8");
AttentionTestData data;
GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(data);
RunMultiHeadAttentionTests(data, true);
}
TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) {
ROCM_GTEST_SKIP("ROCm does not support cutlass");
// ROCM_GTEST_SKIP("ROCm does not support cutlass");
AttentionTestData data;
GetAttentionDataCutlassRelPosBias(data);
RunMultiHeadAttentionTests(data);
@ -555,7 +553,6 @@ TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) {
TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) {
// Whisper decoder cross attention without mask and different sequence lengths for Q and K/V
ROCM_GTEST_SKIP("ROCm not supported");
AttentionTestData data;
GetCrossAttentionData_DiffSequenceLengths(data);
RunMultiHeadAttentionTests(data);
@ -569,7 +566,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) {
TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBias) {
// Whisper decoder self attention with past_kv and present_kv
ROCM_GTEST_SKIP("ROCm not supported");
AttentionTestData data;
GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(data);
RunMultiHeadAttentionTests(data);
@ -583,7 +579,6 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBia
TEST(MultiHeadAttentionTest, CrossAttention_WithPastPassedInDirectly_NoMask) {
// Whisper decoder cross attention with past_kv in place of current KV and no present_kv
ROCM_GTEST_SKIP("ROCm not supported");
AttentionTestData data;
GetCrossAttentionData_WithPastPassedInDirectly_NoMask(data);
RunMultiHeadAttentionTests(data);