mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
[ROCm] Add attention kv cache for decoding (#16076)
This commit is contained in:
parent
96471491d7
commit
9110e5b9bd
9 changed files with 435 additions and 88 deletions
|
|
@ -180,6 +180,12 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
|
|||
const half* past,
|
||||
const half* k_v,
|
||||
half* present);
|
||||
|
||||
template <typename T>
|
||||
Status LaunchStridedCopy(cudaStream_t stream,
|
||||
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
|
||||
T* out, longlong4 out_strides, // coord (b,n,s,h)
|
||||
int max_threads_per_block);
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
158
onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
Normal file
158
onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// Copies one element of type T per thread between two strided tensors that are
// both addressed with (b, n, s, h) coordinates.
//
// Launch layout expected by this kernel:
//   blockIdx.y  -> b      blockIdx.x  -> s
//   threadIdx.y -> n      threadIdx.x -> h
// H is the extent of the h coordinate; threads with h >= H do nothing.
// in_strides / out_strides are given in units of T, in (b, n, s, h) order.
template <typename T>
__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
                            T* out, longlong4 out_strides                    // coord (b,n,s,h)
) {
  const int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  if (h < H) {
    // The strides are 64-bit, so the products below are evaluated in 64-bit.
    // Keep the resulting offsets 64-bit as well: storing them in an int would
    // silently truncate them for tensors with more than 2^31 - 1 elements.
    const int64_t in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
    const int64_t out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
    out[out_offset] = in[in_offset];
  }
}
|
||||
|
||||
// Variant of StridedCopy in which each thread walks the h coordinate in steps
// of blockDim.x, so a block with fewer than H threads along x still covers the
// whole h extent. LaunchStridedCopy selects it when H * num_heads exceeds
// max_threads_per_block.
//
// Launch layout expected by this kernel:
//   blockIdx.y  -> b      blockIdx.x  -> s
//   threadIdx.y -> n      threadIdx.x -> first h handled by the thread
// in_strides / out_strides are given in units of T, in (b, n, s, h) order.
template <typename T>
__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
                                 T* out, longlong4 out_strides                    // coord (b,n,s,h)
) {
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;

  const int h_step = blockDim.x;

  for (int h = threadIdx.x; h < H; h += h_step) {
    // The strides are 64-bit, so the products below are evaluated in 64-bit.
    // Keep the resulting offsets 64-bit as well: storing them in an int would
    // silently truncate them for tensors with more than 2^31 - 1 elements.
    const int64_t in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
    const int64_t out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
    out[out_offset] = in[in_offset];
  }
}
|
||||
|
||||
// Maps a byte count to a type of exactly that size. LaunchStridedCopy uses it
// to reinterpret several consecutive elements as one wider value so that a
// single load/store moves all of them.
//
// The primary template is intentionally left undefined: requesting a byte
// count without a specialization below is a compile-time error.
template <int NumBytes>
struct ToByteType;

template <>
struct ToByteType<2> {
  using T = int16_t;
  static_assert(sizeof(T) == 2, "ToByteType<2>::T must be exactly 2 bytes");
};

template <>
struct ToByteType<4> {
  using T = int32_t;
  static_assert(sizeof(T) == 4, "ToByteType<4>::T must be exactly 4 bytes");
};

template <>
struct ToByteType<8> {
  using T = int64_t;
  static_assert(sizeof(T) == 8, "ToByteType<8>::T must be exactly 8 bytes");
};

template <>
struct ToByteType<16> {
  using T = uint4;
  static_assert(sizeof(T) == 16, "ToByteType<16>::T must be exactly 16 bytes");
};

template <>
struct ToByteType<32> {
  using T = ulonglong4;
  static_assert(sizeof(T) == 32, "ToByteType<32>::T must be exactly 32 bytes");
};

// Convenience alias: ToBytes<N> is the N-byte carrier type.
template <int NumBytes>
using ToBytes = typename ToByteType<NumBytes>::T;
|
||||
|
||||
// Helper for LaunchStridedCopy: performs the copy with PackSize consecutive
// elements along h moved as a single ToBytes<sizeof(T) * PackSize> value.
//
// Preconditions (checked by the caller, not here):
//   - in_shape.w and the b/n/s strides are divisible by PackSize,
//   - the h stride of both tensors is 1,
//   - in and out are aligned to sizeof(T) * PackSize bytes,
//   - every extent in in_shape is non-zero.
template <int PackSize, typename T>
Status LaunchStridedCopyPacked(cudaStream_t stream,
                               const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
                               T* out, longlong4 out_strides,                     // coord (b,n,s,h)
                               int max_threads_per_block) {
  using Bytes = ToBytes<sizeof(T) * PackSize>;

  const int batch_size = in_shape.x;
  const int num_heads = in_shape.y;
  const int sequence_length = in_shape.z;
  const int H = in_shape.w / PackSize;  // h extent measured in packed units

  // Re-express the b/n/s strides in packed units. The h stride stays 1.
  in_strides.x /= PackSize;
  in_strides.y /= PackSize;
  in_strides.z /= PackSize;
  out_strides.x /= PackSize;
  out_strides.y /= PackSize;
  out_strides.z /= PackSize;

  const Bytes* packed_in = reinterpret_cast<const Bytes*>(in);
  Bytes* packed_out = reinterpret_cast<Bytes*>(out);

  // One block per (s, b); threads inside a block cover (h, n).
  const dim3 grid(sequence_length, batch_size);
  if (H * num_heads <= max_threads_per_block) {
    const dim3 block(H, num_heads, 1);
    StridedCopy<Bytes><<<grid, block, 0, stream>>>(packed_in, H, in_strides, packed_out, out_strides);
  } else {
    // Not enough threads for one thread per (h, n): let each thread loop over h.
    // If num_heads itself exceeds max_threads_per_block the launch is invalid
    // and is reported through cudaGetLastError() below.
    const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
    StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(packed_in, H, in_strides, packed_out, out_strides);
  }
  return CUDA_CALL(cudaGetLastError());
}

// Copies a 4-D tensor of shape in_shape = (batch, num_heads, sequence, head_size)
// from `in` to `out` on `stream`. Both tensors are described by strides given in
// elements of T, in (b, n, s, h) coordinate order, so the two sides may use
// different memory layouts.
//
// When it is safe to do so, 4 or 2 consecutive elements along h are moved per
// load/store. Otherwise the copy falls back to one element at a time.
//
// Returns Status::OK() without launching anything when any extent is zero;
// otherwise returns the status of the kernel launch.
template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
                         const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
                         T* out, longlong4 out_strides,                     // coord (b,n,s,h)
                         int max_threads_per_block) {
  // Nothing to copy. A zero-sized grid or block would be an invalid launch.
  if (in_shape.x == 0 || in_shape.y == 0 || in_shape.z == 0 || in_shape.w == 0) {
    return Status::OK();
  }

  // Packing `pack` elements is only valid when
  //  - the h extent splits evenly into packs,
  //  - elements along h are contiguous on both sides (h stride == 1),
  //  - every other stride is a whole number of packs, and
  //  - both base pointers satisfy the alignment of the packed type.
  const auto can_pack = [&](int pack) {
    const auto divisible = [pack](long long v) { return v % pack == 0; };
    const size_t pack_bytes = sizeof(T) * static_cast<size_t>(pack);
    return in_shape.w % pack == 0 &&
           in_strides.w == 1 && out_strides.w == 1 &&
           divisible(in_strides.x) && divisible(in_strides.y) && divisible(in_strides.z) &&
           divisible(out_strides.x) && divisible(out_strides.y) && divisible(out_strides.z) &&
           reinterpret_cast<uintptr_t>(in) % pack_bytes == 0 &&
           reinterpret_cast<uintptr_t>(out) % pack_bytes == 0;
  };

  if (can_pack(4)) {  // pack 4 element together
    return LaunchStridedCopyPacked<4>(stream, in, in_shape, in_strides, out, out_strides, max_threads_per_block);
  }
  if (can_pack(2)) {  // pack 2 element together
    return LaunchStridedCopyPacked<2>(stream, in, in_shape, in_strides, out, out_strides, max_threads_per_block);
  }
  return LaunchStridedCopyPacked<1>(stream, in, in_shape, in_strides, out, out_strides, max_threads_per_block);
}
|
||||
|
||||
// Explicit instantiations: the template is defined only in this translation
// unit, so the float and half versions are emitted here for other files to
// link against.
template Status LaunchStridedCopy<float>(cudaStream_t, const float*, int4, longlong4,
                                         float*, longlong4, int);

template Status LaunchStridedCopy<half>(cudaStream_t, const half*, int4, longlong4,
                                        half*, longlong4, int);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -88,10 +88,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
|
||||
Node().OpType(), &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present}));
|
||||
// TODO: support QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE and QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE
|
||||
ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE ||
|
||||
attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE ||
|
||||
attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE);
|
||||
attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE ||
|
||||
attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE ||
|
||||
attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE);
|
||||
|
||||
size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn);
|
||||
size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn),
|
||||
|
|
@ -123,27 +124,46 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params));
|
||||
auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params);
|
||||
|
||||
// NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH
|
||||
if (nullptr != present) {
|
||||
// Concat past (2xBxNxS'xH) to present (2xBxNxTxH):
|
||||
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxTxH)
|
||||
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxTxH)
|
||||
const int batches = attn.batch_size * attn.num_heads;
|
||||
const int present_size_per_batch = attn.total_sequence_length * attn.head_size;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatPastToPresent(Stream(context),
|
||||
attn.total_sequence_length,
|
||||
attn.sequence_length,
|
||||
attn.batch_size,
|
||||
attn.head_size,
|
||||
attn.num_heads,
|
||||
device_prop.maxThreadsPerBlock,
|
||||
nullptr == past ? nullptr : reinterpret_cast<const HipT*>(past->DataRaw()),
|
||||
k_buffer,
|
||||
reinterpret_cast<HipT*>(present->MutableDataRaw())));
|
||||
Strides dst_strides; // the output buffer is present Tensor, the buffer is the same
|
||||
|
||||
// update pointers to present_k and present_v.
|
||||
int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size};
|
||||
HipT* add_dest = nullptr; // destination of concatenated data to present
|
||||
const HipT* const add_src = k_buffer; // source of concatenated data to present
|
||||
const auto add_src_strides = Strides::BNSHMemory(
|
||||
2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size);
|
||||
|
||||
if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) {
|
||||
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
|
||||
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
|
||||
} else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) {
|
||||
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
|
||||
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
|
||||
|
||||
// We only need to copy past to present in this case. All other cases will build the present incrementally
|
||||
const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size};
|
||||
HipT* const past_dest = reinterpret_cast<HipT*>(present->MutableDataRaw());
|
||||
const HipT* const past_src = reinterpret_cast<const HipT*>(past->DataRaw());
|
||||
const Strides past_src_strides = Strides::BNSHMemory(
|
||||
2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size);
|
||||
|
||||
ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(),
|
||||
past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
|
||||
} else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) {
|
||||
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
|
||||
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/;
|
||||
} else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) {
|
||||
dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
|
||||
add_dest = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(),
|
||||
add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
|
||||
|
||||
// update pointers to present_k and present_v. TODO: switch to ConvertToOffsetedBufferViews
|
||||
k_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw());
|
||||
v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + batches * present_size_per_batch;
|
||||
v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0);
|
||||
}
|
||||
|
||||
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
|
||||
|
|
@ -160,6 +180,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
params.device_prop = &device_prop;
|
||||
// FIXME: the params.scale seems to be different from AttentionParameters::scale;
|
||||
params.scale = 1.0f / sqrt(static_cast<float>(attn.head_size));
|
||||
// TODO: switch to ConvertToOffsetedBufferViews
|
||||
params.q_buffer = q_buffer;
|
||||
params.k_buffer = k_buffer;
|
||||
params.v_buffer = v_buffer;
|
||||
|
|
|
|||
|
|
@ -122,6 +122,24 @@ Status ClassifyAttentionMode(
|
|||
attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE;
|
||||
return Status::OK();
|
||||
}
|
||||
} else if (num_qkv == 3 && num_past == 0 && num_present == 2) {
|
||||
if (attn->past_present_share_buffer == false) {
|
||||
if (attn->qkv_format == Q_K_V_BSNH) {
|
||||
attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH;
|
||||
return Status::OK();
|
||||
} else if (attn->pass_past_in_kv) {
|
||||
attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH;
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
if (attn->qkv_format == Q_K_V_BSNH) {
|
||||
attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH;
|
||||
return Status::OK();
|
||||
} else if (attn->pass_past_in_kv) {
|
||||
attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
} else if (num_qkv == 3 && num_past == 2 && num_present == 2) {
|
||||
if (attn->past_present_share_buffer == false) {
|
||||
if (attn->qkv_format == Q_K_V_BSNH) {
|
||||
|
|
|
|||
|
|
@ -98,28 +98,6 @@ Status LaunchConcatTensorToTensor(hipStream_t stream,
|
|||
const half* tensor_add,
|
||||
half* tensor_out);
|
||||
|
||||
Status LaunchConcatPastToPresent(hipStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const float* past,
|
||||
const float* k_v,
|
||||
float* present);
|
||||
|
||||
Status LaunchConcatPastToPresent(hipStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const half* past,
|
||||
const half* k_v,
|
||||
half* present);
|
||||
|
||||
inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle,
|
||||
rocblas_operation transa,
|
||||
rocblas_operation transb,
|
||||
|
|
@ -191,6 +169,10 @@ enum AttentionMode {
|
|||
QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE,
|
||||
BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE,
|
||||
BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE,
|
||||
BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH,
|
||||
BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH,
|
||||
BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH,
|
||||
BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH,
|
||||
BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH,
|
||||
BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH,
|
||||
BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH,
|
||||
|
|
@ -209,6 +191,11 @@ Status ClassifyAttentionMode(const std::string& op,
|
|||
const std::vector<const Tensor*>& past,
|
||||
const std::vector<Tensor*>& present);
|
||||
|
||||
template <typename T>
|
||||
Status LaunchStridedCopy(hipStream_t stream,
|
||||
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
|
||||
T* out, longlong4 out_strides, // coord (b,n,s,h)
|
||||
int max_threads_per_block);
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -67,6 +67,10 @@ are in composable kernels. The scale and add logic is performed via Acc0ElementO
|
|||
| BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic
|
||||
| BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true
|
||||
| BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false
|
||||
| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false
|
||||
| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false
|
||||
| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true
|
||||
| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true
|
||||
| BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
|
||||
| BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false
|
||||
| BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true
|
||||
|
|
@ -76,7 +80,7 @@ are in composable kernels. The scale and add logic is performed via Acc0ElementO
|
|||
|
||||
Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel
|
||||
|
||||
About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_b for v_buffer
|
||||
About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer
|
||||
|
||||
- A Tensor marked with `*` is used for k_buffer passing.
|
||||
- A Tensor marked with `^` is used for v_buffer passing.
|
||||
|
|
@ -162,7 +166,7 @@ struct Strides {
|
|||
}
|
||||
|
||||
template <typename T = longlong4>
|
||||
T ForBNSHCoord() {
|
||||
T ForBNSHCoord() const {
|
||||
using E = typename T::value_type;
|
||||
return T{static_cast<E>(strides_for_bnsh_coord.x),
|
||||
static_cast<E>(strides_for_bnsh_coord.y),
|
||||
|
|
@ -171,7 +175,7 @@ struct Strides {
|
|||
}
|
||||
|
||||
template <typename T = longlong4>
|
||||
T ForBSNHCoord() {
|
||||
T ForBSNHCoord() const {
|
||||
using E = typename T::value_type;
|
||||
return T{static_cast<E>(strides_for_bnsh_coord.x),
|
||||
static_cast<E>(strides_for_bnsh_coord.z),
|
||||
|
|
@ -180,7 +184,7 @@ struct Strides {
|
|||
}
|
||||
|
||||
template <typename T = longlong4>
|
||||
T ForBNHSCoord() {
|
||||
T ForBNHSCoord() const {
|
||||
using E = typename T::value_type;
|
||||
return T{static_cast<E>(strides_for_bnsh_coord.x),
|
||||
static_cast<E>(strides_for_bnsh_coord.y),
|
||||
|
|
@ -188,6 +192,11 @@ struct Strides {
|
|||
static_cast<E>(strides_for_bnsh_coord.z)};
|
||||
}
|
||||
|
||||
int64_t OffsetAt(int b, int n, int s, int h) const {
|
||||
return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n +
|
||||
strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h;
|
||||
}
|
||||
|
||||
// store intermediate strides in the canonical (b,n,s,h) coordinate order
|
||||
longlong4 strides_for_bnsh_coord;
|
||||
};
|
||||
|
|
@ -200,13 +209,15 @@ std::tuple<const HipT*, const HipT*, const HipT*> ConvertToOffsetedBufferViews(
|
|||
const T* value = nullptr, //
|
||||
const T* present = nullptr, // present or present_k
|
||||
const T* present_v = nullptr) {
|
||||
ORT_UNUSED_PARAMETER(present_v);
|
||||
switch (attn->mode) {
|
||||
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: {
|
||||
case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: {
|
||||
return {reinterpret_cast<const HipT*>(query),
|
||||
reinterpret_cast<const HipT*>(key),
|
||||
reinterpret_cast<const HipT*>(value)};
|
||||
}
|
||||
case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE:
|
||||
case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: {
|
||||
auto offset = static_cast<int64_t>(attn->batch_size) * attn->num_heads * attn->total_sequence_length *
|
||||
attn->head_size;
|
||||
|
|
@ -214,6 +225,25 @@ std::tuple<const HipT*, const HipT*, const HipT*> ConvertToOffsetedBufferViews(
|
|||
reinterpret_cast<const HipT*>(present),
|
||||
reinterpret_cast<const HipT*>(present) + offset};
|
||||
}
|
||||
case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE:
|
||||
case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: {
|
||||
auto offset = static_cast<int64_t>(attn->batch_size) * attn->num_heads * attn->max_sequence_length *
|
||||
attn->head_size;
|
||||
return {reinterpret_cast<const HipT*>(query),
|
||||
reinterpret_cast<const HipT*>(present),
|
||||
reinterpret_cast<const HipT*>(present) + offset};
|
||||
}
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
|
||||
case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
|
||||
case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
|
||||
case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
|
||||
case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
|
||||
return {reinterpret_cast<const HipT*>(query),
|
||||
reinterpret_cast<const HipT*>(present),
|
||||
reinterpret_cast<const HipT*>(present_v)};
|
||||
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: {
|
||||
auto packed_kv = reinterpret_cast<const HipT*>(key);
|
||||
return {reinterpret_cast<const HipT*>(query), packed_kv, packed_kv + attn->head_size};
|
||||
|
|
@ -234,7 +264,8 @@ inline std::tuple<Strides, Strides, Strides> GetQkvStrides(const RocmAttentionPa
|
|||
const int& N = attn->num_heads;
|
||||
const int& S = attn->sequence_length;
|
||||
const int& L = attn->kv_sequence_length;
|
||||
// const int& T = attn->total_sequence_length;
|
||||
const int& T = attn->total_sequence_length;
|
||||
const int& M = attn->max_sequence_length;
|
||||
const int& H = attn->head_size;
|
||||
const int& Hv = attn->v_head_size;
|
||||
|
||||
|
|
@ -253,6 +284,36 @@ inline std::tuple<Strides, Strides, Strides> GetQkvStrides(const RocmAttentionPa
|
|||
Strides::BSNHMemory(B, L, N, Hv),
|
||||
};
|
||||
}
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
|
||||
case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
|
||||
case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
|
||||
return {
|
||||
Strides::BSNHMemory(B, S, N, H),
|
||||
Strides::BNSHMemory(B, N, T, H),
|
||||
Strides::BNSHMemory(B, N, T, Hv),
|
||||
};
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
|
||||
case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
|
||||
case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
|
||||
return {
|
||||
Strides::BSNHMemory(B, S, N, H),
|
||||
Strides::BNSHMemory(B, N, M, H),
|
||||
Strides::BNSHMemory(B, N, M, Hv),
|
||||
};
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE:
|
||||
return {
|
||||
Strides::BSNHMemory(B, S, N, H),
|
||||
Strides::BSNHMemory(B, L, N, H),
|
||||
Strides::BSNHMemory(B, L, N, Hv),
|
||||
};
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
|
||||
return {
|
||||
Strides::BSNHMemory(B, S, N, H),
|
||||
Strides::BNSHMemory(B, N, L, H),
|
||||
Strides::BNSHMemory(B, N, L, Hv),
|
||||
};
|
||||
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
|
||||
return {
|
||||
Strides::BSNHMemory(B, S, N, H),
|
||||
|
|
@ -310,8 +371,10 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
|
|||
"_T", attention->total_sequence_length,
|
||||
"_N", attention->num_heads,
|
||||
"_H", attention->head_size,
|
||||
"_Hv", attention->v_head_size,
|
||||
bias_buffer != nullptr ? "_B" : "_NB",
|
||||
"_M", mask_index_dims.size(),
|
||||
"_QKV", attention->qkv_format,
|
||||
"_MODE", attention->mode);
|
||||
}
|
||||
|
||||
|
|
@ -320,7 +383,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
|
|||
auto m = attention->sequence_length;
|
||||
auto n = attention->total_sequence_length;
|
||||
auto k = attention->head_size;
|
||||
auto o = attention->head_size;
|
||||
auto o = attention->v_head_size;
|
||||
auto batch = attention->batch_size * attention->num_heads;
|
||||
return {m, n, k, o, batch};
|
||||
}
|
||||
|
|
@ -491,6 +554,16 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
|
|||
} else {
|
||||
return false;
|
||||
}
|
||||
case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH:
|
||||
case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH:
|
||||
case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH:
|
||||
case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH:
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE:
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH:
|
||||
case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH:
|
||||
case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH:
|
||||
case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE:
|
||||
case BLN3H_NONE_NONE_NONE_NONE_NONE_NONE:
|
||||
return true;
|
||||
|
|
@ -554,7 +627,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
|
|||
*reinterpret_cast<StoreT*>(out + out_offset) = store;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = tidx; i < mask_lengths.z; i++) {
|
||||
for (int i = 0; i < mask_lengths.z - tidx; i++) {
|
||||
out[out_offset + i] = cvt(mask_buffer[in_offset + i]);
|
||||
}
|
||||
}
|
||||
|
|
@ -639,7 +712,6 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
|
|||
{
|
||||
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
|
||||
ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
|
||||
ORT_ENFORCE(K == O, "inner product dimension mismatch");
|
||||
}
|
||||
|
||||
auto [qs, ks, vs] = GetQkvStrides(attn);
|
||||
|
|
|
|||
|
|
@ -58,9 +58,11 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* past_key = context->Input<Tensor>(6);
|
||||
const Tensor* past_value = context->Input<Tensor>(7);
|
||||
|
||||
// TODO: Add support for bias, key_padding_mask and attention cache.
|
||||
ORT_ENFORCE(bias == nullptr && key_padding_mask == nullptr && past_key == nullptr && past_value == nullptr,
|
||||
"bias, key_padding_mask and attention cache is not supported");
|
||||
if (nullptr != bias) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"qkv_bias is not supported on ROCm EP. "
|
||||
"User should fuse the qkv bias to qkv projection instead.");
|
||||
}
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
RocmAttentionParameters attn;
|
||||
|
|
@ -72,14 +74,10 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
&attn,
|
||||
num_heads_, mask_filter_value_, scale_,
|
||||
false, false, device_prop.maxThreadsPerBlock));
|
||||
// TODO: support more qkv formats
|
||||
ORT_ENFORCE(attn.qkv_format == Q_KV_BSNH_BSN2H || attn.qkv_format == QKV_BSN3H, "Got ", attn.qkv_format);
|
||||
|
||||
int sequence_length = attn.sequence_length;
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(attn.batch_size);
|
||||
output_shape[1] = static_cast<int64_t>(sequence_length);
|
||||
output_shape[1] = static_cast<int64_t>(attn.sequence_length);
|
||||
output_shape[2] = static_cast<int64_t>(attn.v_hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
|
|
@ -92,8 +90,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
TensorShape present_shape(present_dims);
|
||||
Tensor* present_key = context->Output(1, present_shape);
|
||||
Tensor* present_value = context->Output(2, present_shape);
|
||||
// TODO: Add support for attention cache
|
||||
ORT_ENFORCE(present_key == nullptr && present_value == nullptr, "attention cache is not supported");
|
||||
|
||||
ORT_RETURN_IF_ERROR(ClassifyAttentionMode(
|
||||
Node().OpType(), &attn,
|
||||
|
|
@ -101,26 +97,114 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
/*past=*/{past_key, past_value},
|
||||
/*present=*/{present_key, present_value}));
|
||||
|
||||
|
||||
using HipT = typename ToHipType<T>::MappedType;
|
||||
using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp<HipT>;
|
||||
auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn);
|
||||
auto workspace = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
|
||||
|
||||
hipStream_t stream = Stream(context);
|
||||
if (nullptr != present_key) { // process past present concat
|
||||
Strides dst_strides;
|
||||
|
||||
int4 past_shape;
|
||||
Strides past_src_strides;
|
||||
const HipT* past_key_src;
|
||||
const HipT* past_value_src;
|
||||
HipT* past_key_dst{};
|
||||
HipT* past_value_dst{};
|
||||
|
||||
int4 add_shape;
|
||||
Strides add_src_strides;
|
||||
const HipT* add_key_src =reinterpret_cast<const HipT*>(key->DataRaw());
|
||||
const HipT* add_value_src = reinterpret_cast<const HipT*>(value->DataRaw());
|
||||
HipT* add_key_dst;
|
||||
HipT* add_value_dst;
|
||||
|
||||
if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH ||
|
||||
attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) {
|
||||
dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
|
||||
|
||||
past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size};
|
||||
past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size);
|
||||
past_key_src = reinterpret_cast<const HipT*>(past_key->DataRaw());
|
||||
past_value_src = reinterpret_cast<const HipT*>(past_value->DataRaw());
|
||||
past_key_dst = reinterpret_cast<HipT*>(present_key->MutableDataRaw());
|
||||
past_value_dst = reinterpret_cast<HipT*>(present_value->MutableDataRaw());
|
||||
|
||||
if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) {
|
||||
add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size);
|
||||
} else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) {
|
||||
add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size);
|
||||
}
|
||||
} else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH ||
|
||||
attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) {
|
||||
dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size);
|
||||
|
||||
if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) {
|
||||
add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size);
|
||||
} else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) {
|
||||
add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size);
|
||||
}
|
||||
} else if (
|
||||
attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH ||
|
||||
attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH ||
|
||||
attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH ||
|
||||
attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) {
|
||||
dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size);
|
||||
|
||||
if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) {
|
||||
add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size);
|
||||
} else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) {
|
||||
add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size);
|
||||
}
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"past present concatenation is not implemented for attention mode ", attn.mode);
|
||||
}
|
||||
add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h)
|
||||
add_key_dst = reinterpret_cast<HipT*>(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
|
||||
add_value_dst = reinterpret_cast<HipT*>(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0);
|
||||
|
||||
if (past_key_dst) {
|
||||
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
|
||||
stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(),
|
||||
past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
|
||||
}
|
||||
if (past_value_dst) {
|
||||
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
|
||||
stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(),
|
||||
past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
|
||||
stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(),
|
||||
add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
|
||||
ORT_RETURN_IF_ERROR(LaunchStridedCopy(
|
||||
stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(),
|
||||
add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock));
|
||||
}
|
||||
|
||||
GemmSoftmaxGemmPermuteParams<HipT> params;
|
||||
params.tuning_ctx = GetTuningContext();
|
||||
params.stream = Stream(context);
|
||||
params.stream = stream;
|
||||
params.handle = GetRocblasHandle(context);
|
||||
params.attention = &attn;
|
||||
params.device_prop = &device_prop;
|
||||
params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_;
|
||||
std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews<HipT>(
|
||||
&attn,
|
||||
query->DataRaw(),
|
||||
key == nullptr ? nullptr : key->DataRaw(),
|
||||
value == nullptr ? nullptr : value->DataRaw());
|
||||
nullptr == query ? nullptr : reinterpret_cast<const HipT*>(query->DataRaw()),
|
||||
nullptr == key ? nullptr : reinterpret_cast<const HipT*>(key->DataRaw()),
|
||||
nullptr == value ? nullptr : reinterpret_cast<const HipT*>(value->DataRaw()),
|
||||
nullptr == present_key ? nullptr : reinterpret_cast<const HipT*>(present_key->DataRaw()),
|
||||
nullptr == present_value ? nullptr : reinterpret_cast<const HipT*>(present_value->DataRaw()));
|
||||
params.out_buffer = reinterpret_cast<HipT*>(output->MutableDataRaw());
|
||||
|
||||
if (key_padding_mask != nullptr) {
|
||||
params.mask_index_buffer = key_padding_mask->Data<int>();
|
||||
params.mask_index_dims = key_padding_mask->Shape().AsShapeVector();
|
||||
}
|
||||
|
||||
if (relative_position_bias != nullptr) {
|
||||
params.bias_buffer = reinterpret_cast<const HipT*>(relative_position_bias->DataRaw());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -239,6 +239,12 @@ static void RunAttentionTest(
|
|||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
if (enable_rocm) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true));
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
if (enable_cpu) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include "test/util/include/scoped_env_vars.h"
|
||||
#include "test/contrib_ops/attention_op_test_helper.h"
|
||||
|
||||
#if defined(USE_ROCM) && defined(USE_COMPOSABLE_KERNEL)
|
||||
#if defined(USE_ROCM) && defined(USE_COMPOSABLE_KERNEL) && !defined(USE_MIGRAPHX)
|
||||
#define DISABLE_ROCM false
|
||||
#else
|
||||
#define DISABLE_ROCM true
|
||||
|
|
@ -21,12 +21,6 @@
|
|||
#define ROCM_GTEST_SKIP(message)
|
||||
#endif
|
||||
|
||||
#if defined(USE_MIGRAPHX)
|
||||
#define MIGX_GTEST_SKIP(message) GTEST_SKIP_(message)
|
||||
#else
|
||||
#define MIGX_GTEST_SKIP(message)
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
|
|
@ -66,6 +60,16 @@ static void RunMultiHeadAttentionTest(
|
|||
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu;
|
||||
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml;
|
||||
|
||||
if (enable_rocm && !use_float16) {
|
||||
LOGS_DEFAULT(WARNING) << "ROCm MHA only have kernel for half datatype implemented, skip float datatype tests";
|
||||
enable_rocm = false;
|
||||
}
|
||||
|
||||
if (enable_rocm && !bias_data.empty()) {
|
||||
LOGS_DEFAULT(WARNING) << "ROCm MHA does not support qkv_bias, skip qkv_bias tests";
|
||||
enable_rocm = false;
|
||||
}
|
||||
|
||||
if (enable_cpu || enable_cuda || enable_rocm || enable_dml) {
|
||||
OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
|
||||
|
|
@ -457,7 +461,6 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu
|
|||
// Test fused cross attention kernel
|
||||
// It requires head_size > 32 and head_size <= 64 for T4 GPU; hidden_size == v_hidden_size.
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support bias");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_HeadSize40(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
@ -467,7 +470,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
|
|||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support mask");
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true);
|
||||
RunMultiHeadAttentionTests(data, true);
|
||||
|
|
@ -477,7 +480,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
|
|||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support mask");
|
||||
ROCM_GTEST_SKIP("ROCm MHA expect failure due to ck bug");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false);
|
||||
RunMultiHeadAttentionTests(data, true);
|
||||
|
|
@ -487,7 +490,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M
|
|||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support mask");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
|
||||
RunMultiHeadAttentionTests(data, true);
|
||||
|
|
@ -497,14 +499,12 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma
|
|||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) {
|
||||
MIGX_GTEST_SKIP("MIGX MHA does not support Packed KV");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) {
|
||||
MIGX_GTEST_SKIP("MIGX MHA does not support Packed QKV");
|
||||
AttentionTestData data;
|
||||
GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
@ -513,7 +513,6 @@ TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_Packe
|
|||
|
||||
// This tests qk_head_size != v_head_size
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support bias");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_HeadSize16_8(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
@ -523,7 +522,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) {
|
|||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support bias");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_HeadSize16(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
@ -533,21 +531,21 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
|
|||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttentionWithPast) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support attention cache");
|
||||
ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionDataWithPast(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithRelPosBias_ForT5) {
|
||||
ROCM_GTEST_SKIP("ROCm MHA does not support attention cache");
|
||||
ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8");
|
||||
AttentionTestData data;
|
||||
GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(data);
|
||||
RunMultiHeadAttentionTests(data, true);
|
||||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) {
|
||||
ROCM_GTEST_SKIP("ROCm does not support cutlass");
|
||||
// ROCM_GTEST_SKIP("ROCm does not support cutlass");
|
||||
AttentionTestData data;
|
||||
GetAttentionDataCutlassRelPosBias(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
@ -555,7 +553,6 @@ TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) {
|
|||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) {
|
||||
// Whisper decoder cross attention without mask and different sequence lengths for Q and K/V
|
||||
ROCM_GTEST_SKIP("ROCm not supported");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_DiffSequenceLengths(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
@ -569,7 +566,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) {
|
|||
|
||||
TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBias) {
|
||||
// Whisper decoder self attention with past_kv and present_kv
|
||||
ROCM_GTEST_SKIP("ROCm not supported");
|
||||
AttentionTestData data;
|
||||
GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
@ -583,7 +579,6 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBia
|
|||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_WithPastPassedInDirectly_NoMask) {
|
||||
// Whisper decoder cross attention with past_kv in place of current KV and no present_kv
|
||||
ROCM_GTEST_SKIP("ROCm not supported");
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_WithPastPassedInDirectly_NoMask(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
|
|
|
|||
Loading…
Reference in a new issue