From 9110e5b9bde306a409e86e71db3bdc5a2dda3829 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Fri, 16 Jun 2023 14:17:56 +0800 Subject: [PATCH] [ROCm] Add attention kv cache for decoding (#16076) --- .../contrib_ops/cuda/bert/attention_impl.h | 6 + .../cuda/bert/attention_strided_copy.cu | 158 ++++++++++++++++++ .../contrib_ops/rocm/bert/attention.cu | 61 ++++--- .../contrib_ops/rocm/bert/attention_impl.cu | 18 ++ .../contrib_ops/rocm/bert/attention_impl.h | 31 +--- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 92 ++++++++-- .../rocm/bert/multihead_attention.cu | 114 +++++++++++-- .../test/contrib_ops/attention_op_test.cc | 6 + .../multihead_attention_op_test.cc | 37 ++-- 9 files changed, 435 insertions(+), 88 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 3c0fa164fb..afc5a065a0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -180,6 +180,12 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, const half* past, const half* k_v, half* present); + +template +Status LaunchStridedCopy(cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu new file mode 100644 index 0000000000..1466f5fcfe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/bert/attention_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides // coord (b,n,s,h) +) { + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + if (h < H) { + const int64_t in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; // 64-bit: strides are longlong4, an int offset would truncate + const int64_t out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + out[out_offset] = in[in_offset]; + } +} + +template +__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides // coord (b,n,s,h) +) { + // Launched when H * num_heads > max_threads_per_block; each thread walks h in steps of blockDim.x + int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int h_step = blockDim.x; + + while (h < H) { + const int64_t in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; // 64-bit: strides are longlong4, an int offset would truncate + const int64_t out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + out[out_offset] = in[in_offset]; + h += h_step; + } +} + +template +struct ToByteType; + +template <> +struct ToByteType<2> { + using T = int16_t; +}; + +template <> +struct ToByteType<4> { + using T = int32_t; +}; + +template <> +struct ToByteType<8> { + using T = int64_t; +}; + +template <> +struct ToByteType<16> { + using T = uint4; +}; + +template <> +struct ToByteType<32> { + using T = ulonglong4; +}; + +template +using ToBytes = typename ToByteType::T; + +template +Status LaunchStridedCopy(cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block) { + int 
batch_size = in_shape.x; + int num_heads = in_shape.y; + int sequence_length = in_shape.z; + int head_size = in_shape.w; + if (sequence_length == 0) { + return Status::OK(); + } + + const dim3 grid(sequence_length, batch_size); + if (0 == (head_size % 4)) { // pack 4 element together + using Bytes = ToBytes; + const int H = head_size / 4; + in_strides.x /= 4; + in_strides.y /= 4; + in_strides.z /= 4; + out_strides.x /= 4; + out_strides.y /= 4; + out_strides.z /= 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + StridedCopy<<>>(reinterpret_cast(in), H, in_strides, + reinterpret_cast(out), out_strides); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, + reinterpret_cast(out), out_strides); + } + } else if (0 == (head_size % 2)) { // pack 2 element together + using Bytes = ToBytes; + const int H = head_size / 2; + in_strides.x /= 2; + in_strides.y /= 2; + in_strides.z /= 2; + out_strides.x /= 2; + out_strides.y /= 2; + out_strides.z /= 2; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + StridedCopy<<>>(reinterpret_cast(in), H, in_strides, + reinterpret_cast(out), out_strides); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, + reinterpret_cast(out), out_strides); + } + } else { + using Bytes = ToBytes; + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + StridedCopy<<>>(reinterpret_cast(in), head_size, in_strides, + reinterpret_cast(out), out_strides); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + StridedCopyLarge<<>>(reinterpret_cast(in), head_size, in_strides, + reinterpret_cast(out), out_strides); + } + } + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchStridedCopy( + cudaStream_t stream, + const float* in, 
int4 in_shape, longlong4 in_strides, + float* out, longlong4 out_strides, + int max_threads_per_block); + +template Status LaunchStridedCopy( + cudaStream_t stream, + const half* in, int4 in_shape, longlong4 in_strides, + half* out, longlong4 out_strides, + int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu index 124116497c..e7bb7a9a04 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -88,10 +88,11 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(ClassifyAttentionMode( Node().OpType(), &attn, /*qkv=*/{}, /*past=*/{past}, /*present=*/{present})); - // TODO: support QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE and QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE ORT_ENFORCE(attn.mode == QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE || attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE || - attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE); + attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE || + attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE || + attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE); size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); size_t shared_workspace_bytes = std::max(QkvProjectGeneric::GetWorkspaceNumBytes(&attn), @@ -123,27 +124,46 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params)); auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params); + // NOTE: GemmPermute always output 3BNSH, k_buffer and v_buffer can be treated as 2BNSH if (nullptr != present) { - // Concat past (2xBxNxS'xH) to present (2xBxNxTxH): - // past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxTxH) - // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxTxH) - const int batches = attn.batch_size 
* attn.num_heads; - const int present_size_per_batch = attn.total_sequence_length * attn.head_size; - ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent(Stream(context), - attn.total_sequence_length, - attn.sequence_length, - attn.batch_size, - attn.head_size, - attn.num_heads, - device_prop.maxThreadsPerBlock, - nullptr == past ? nullptr : reinterpret_cast(past->DataRaw()), - k_buffer, - reinterpret_cast(present->MutableDataRaw()))); + Strides dst_strides; // the output buffer is present Tensor, the buffer is the same - // update pointers to present_k and present_v. + int4 add_shape{2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size}; + HipT* add_dest = nullptr; // destination of concatenated data to present + const HipT* const add_src = k_buffer; // source of concatenated data to present + const auto add_src_strides = Strides::BNSHMemory( + 2 * attn.batch_size, attn.num_heads, attn.sequence_length, attn.head_size); + + if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; + } else if (attn.mode == QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + + // We only need to copy past to present in this case. 
All other cases will be build the present incrementally + const int4 past_shape = {2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; + HipT* const past_dest = reinterpret_cast(present->MutableDataRaw()); + const HipT* const past_src = reinterpret_cast(past->DataRaw()); + const Strides past_src_strides = Strides::BNSHMemory( + 2 * attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); + + ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, past_src, past_shape, past_src_strides.ForBNSHCoord(), + past_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } else if (attn.mode == QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) /* + dst_strides.OffsetAt(0, 0, 0, 0)*/; + } else if (attn.mode == QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE) { + dst_strides = Strides::BNSHMemory(2 * attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + add_dest = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + } + + ORT_RETURN_IF_ERROR(LaunchStridedCopy(stream, add_src, add_shape, add_src_strides.ForBNSHCoord(), + add_dest, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + + // update pointers to present_k and present_v. 
TODO: switch to ConvertToOffsetedBufferViews k_buffer = reinterpret_cast(present->MutableDataRaw()); - v_buffer = reinterpret_cast(present->MutableDataRaw()) + batches * present_size_per_batch; + v_buffer = reinterpret_cast(present->MutableDataRaw()) + dst_strides.OffsetAt(attn.batch_size, 0, 0, 0); } // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax @@ -160,6 +180,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { params.device_prop = &device_prop; // FIXME: the params.scale seems to be different from AttentionParameters::scale; params.scale = 1.0f / sqrt(static_cast(attn.head_size)); + // TODO: switch to ConvertToOffsetedBufferViews params.q_buffer = q_buffer; params.k_buffer = k_buffer; params.v_buffer = v_buffer; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index b31311e6ed..40d9036452 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -122,6 +122,24 @@ Status ClassifyAttentionMode( attn->mode = BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE; return Status::OK(); } + } else if (num_qkv == 3 && num_past == 0 && num_present == 2) { + if (attn->past_present_share_buffer == false) { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH; + return Status::OK(); + } else if (attn->pass_past_in_kv) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH; + return Status::OK(); + } + } else { + if (attn->qkv_format == Q_K_V_BSNH) { + attn->mode = BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH; + return Status::OK(); + } else if (attn->pass_past_in_kv) { + attn->mode = BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH; + return Status::OK(); + } + } } else if (num_qkv == 3 && num_past == 2 && num_present == 2) { if (attn->past_present_share_buffer == false) { if (attn->qkv_format == Q_K_V_BSNH) { diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h 
b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 5b6ec6de70..f4e68bea19 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -98,28 +98,6 @@ Status LaunchConcatTensorToTensor(hipStream_t stream, const half* tensor_add, half* tensor_out); -Status LaunchConcatPastToPresent(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present); - -Status LaunchConcatPastToPresent(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present); - inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, @@ -191,6 +169,10 @@ enum AttentionMode { QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE, BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE, BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE, + BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH, + BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH, + BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH, + BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH, BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH, BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH, BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH, @@ -209,6 +191,11 @@ Status ClassifyAttentionMode(const std::string& op, const std::vector& past, const std::vector& present); +template +Status LaunchStridedCopy(hipStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block); } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh 
b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index c8febbc795..14e1430b41 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -67,6 +67,10 @@ are in composable kernels. The scale and add logic is performed via Acc0ElementO | BSNH | BLNH*| BLNH^| - | - | - | - | MHA basic | BSNH | BNLH*| BNLH^| - | - | - | - | MHA cross, pass_past_in_kv = true | BSNH | - | - | - | - | BNLH * | BNLH ^ | MHA cross, pass_past_in_kv = false +| BSNH | BLNH | BLNH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false +| BSNH | BNLH | BNLH | - | - | BNTH * | BNTH ^ | MHA cross, past_present_share_buffer = false +| BSNH | BLNH | BLNH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true +| BSNH | BNLH | BNLH | - | - | BNMH * | BNMH ^ | MHA cross, past_present_share_buffer = true | BSNH | BLNH | BLNH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false | BSNH | BNLH | BNLH | BNPH | BNPH | BNTH * | BNTH ^ | MHA self, past_present_share_buffer = false | BSNH | BLNH | BLNH | BNMH | BNMH | BNMH * | BNMH ^ | MHA self, past_present_share_buffer = true @@ -76,7 +80,7 @@ are in composable kernels. The scale and add logic is performed via Acc0ElementO Q, K, V, past(K), pastV, present(K), presentV is the Input of the contrib OpKernel -About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_b for v_buffer +About k_buffer and v_buffer, we always explicitly concat past to present and use present_k for k_buffer and present_v for v_buffer - Marked with `*` indicate the Tensor is used for k_buffer passing. - Marked with `^` indicate the Tensor is used for v_buffer passing. 
@@ -162,7 +166,7 @@ struct Strides { } template - T ForBNSHCoord() { + T ForBNSHCoord() const { using E = typename T::value_type; return T{static_cast(strides_for_bnsh_coord.x), static_cast(strides_for_bnsh_coord.y), @@ -171,7 +175,7 @@ struct Strides { } template - T ForBSNHCoord() { + T ForBSNHCoord() const { using E = typename T::value_type; return T{static_cast(strides_for_bnsh_coord.x), static_cast(strides_for_bnsh_coord.z), @@ -180,7 +184,7 @@ struct Strides { } template - T ForBNHSCoord() { + T ForBNHSCoord() const { using E = typename T::value_type; return T{static_cast(strides_for_bnsh_coord.x), static_cast(strides_for_bnsh_coord.y), @@ -188,6 +192,11 @@ struct Strides { static_cast(strides_for_bnsh_coord.z)}; } + int64_t OffsetAt(int b, int n, int s, int h) const { + return strides_for_bnsh_coord.x * b + strides_for_bnsh_coord.y * n + + strides_for_bnsh_coord.z * s + strides_for_bnsh_coord.w * h; + } + // store intermediate strides in the canonical (b,n,s,h) coordinate order longlong4 strides_for_bnsh_coord; }; @@ -200,13 +209,15 @@ std::tuple ConvertToOffsetedBufferViews( const T* value = nullptr, // const T* present = nullptr, // present or present_k const T* present_v = nullptr) { - ORT_UNUSED_PARAMETER(present_v); switch (attn->mode) { - case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: { + case QFMT_KFMT_VFMT_NONE_NONE_NONE_NONE: + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: { return {reinterpret_cast(query), reinterpret_cast(key), reinterpret_cast(value)}; } + case QFMT_KFMT_VFMT_NONE_NONE_2BNTH_NONE: case QFMT_KFMT_VFMT_2BNPH_NONE_2BNTH_NONE: { auto offset = static_cast(attn->batch_size) * attn->num_heads * attn->total_sequence_length * attn->head_size; @@ -214,6 +225,25 @@ std::tuple ConvertToOffsetedBufferViews( reinterpret_cast(present), reinterpret_cast(present) + offset}; } + case QFMT_KFMT_VFMT_NONE_NONE_2BNMH_NONE: + case QFMT_KFMT_VFMT_2BNMH_NONE_2BNMH_NONE: { + auto offset = static_cast(attn->batch_size) * 
attn->num_heads * attn->max_sequence_length * + attn->head_size; + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present) + offset}; + } + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return {reinterpret_cast(query), + reinterpret_cast(present), + reinterpret_cast(present_v)}; case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: { auto packed_kv = reinterpret_cast(key); return {reinterpret_cast(query), packed_kv, packed_kv + attn->head_size}; @@ -234,7 +264,8 @@ inline std::tuple GetQkvStrides(const RocmAttentionPa const int& N = attn->num_heads; const int& S = attn->sequence_length; const int& L = attn->kv_sequence_length; - // const int& T = attn->total_sequence_length; + const int& T = attn->total_sequence_length; + const int& M = attn->max_sequence_length; const int& H = attn->head_size; const int& Hv = attn->v_head_size; @@ -253,6 +284,36 @@ inline std::tuple GetQkvStrides(const RocmAttentionPa Strides::BSNHMemory(B, L, N, Hv), }; } + case BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH: + case BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH: + case BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH: + case BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, T, H), + Strides::BNSHMemory(B, N, T, Hv), + }; + case BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH: + case BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH: + case BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH: + case BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, M, H), + Strides::BNSHMemory(B, N, M, Hv), + }; + case BSNH_BLNH_BLNH_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BSNHMemory(B, L, N, H), + 
Strides::BSNHMemory(B, L, N, Hv), + }; + case BSNH_BNLH_BNLH_NONE_NONE_NONE_NONE: + return { + Strides::BSNHMemory(B, S, N, H), + Strides::BNSHMemory(B, N, L, H), + Strides::BNSHMemory(B, N, L, Hv), + }; case BSNH_BLN2H_NONE_NONE_NONE_NONE_NONE: return { Strides::BSNHMemory(B, S, N, H), @@ -310,8 +371,10 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { "_T", attention->total_sequence_length, "_N", attention->num_heads, "_H", attention->head_size, + "_Hv", attention->v_head_size, bias_buffer != nullptr ? "_B" : "_NB", "_M", mask_index_dims.size(), + "_QKV", attention->qkv_format, "_MODE", attention->mode); } @@ -320,7 +383,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { auto m = attention->sequence_length; auto n = attention->total_sequence_length; auto k = attention->head_size; - auto o = attention->head_size; + auto o = attention->v_head_size; auto batch = attention->batch_size * attention->num_heads; return {m, n, k, o, batch}; } @@ -491,6 +554,16 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp(out + out_offset) = store; } else { #pragma unroll - for (int i = tidx; i < mask_lengths.z; i++) { + for (int i = 0; i < mask_lengths.z - tidx; i++) { out[out_offset + i] = cvt(mask_buffer[in_offset + i]); } } @@ -639,7 +712,6 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { { auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - ORT_ENFORCE(K == O, "inner product dimension mismatch"); } auto [qs, ks, vs] = GetQkvStrides(attn); diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index aa8a87a1da..ba634152ec 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -58,9 +58,11 @@ Status 
MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_key = context->Input(6); const Tensor* past_value = context->Input(7); - // TODO: Add support for bias, key_padding_mask and attention cache. - ORT_ENFORCE(bias == nullptr && key_padding_mask == nullptr && past_key == nullptr && past_value == nullptr, - "bias, key_padding_mask and attention cache is not supported"); + if (nullptr != bias) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "qkv_bias is not supported on ROCm EP. " + "User should fuse the qkv bias to qkv projection instead."); + } auto& device_prop = GetDeviceProp(); RocmAttentionParameters attn; @@ -72,14 +74,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { &attn, num_heads_, mask_filter_value_, scale_, false, false, device_prop.maxThreadsPerBlock)); - // TODO: support more qkv formats - ORT_ENFORCE(attn.qkv_format == Q_KV_BSNH_BSN2H || attn.qkv_format == QKV_BSN3H, "Got ", attn.qkv_format); - - int sequence_length = attn.sequence_length; TensorShapeVector output_shape(3); output_shape[0] = static_cast(attn.batch_size); - output_shape[1] = static_cast(sequence_length); + output_shape[1] = static_cast(attn.sequence_length); output_shape[2] = static_cast(attn.v_hidden_size); Tensor* output = context->Output(0, output_shape); @@ -92,8 +90,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { TensorShape present_shape(present_dims); Tensor* present_key = context->Output(1, present_shape); Tensor* present_value = context->Output(2, present_shape); - // TODO: Add support for attention cache - ORT_ENFORCE(present_key == nullptr && present_value == nullptr, "attention cache is not supported"); ORT_RETURN_IF_ERROR(ClassifyAttentionMode( Node().OpType(), &attn, @@ -101,26 +97,114 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { /*past=*/{past_key, past_value}, /*present=*/{present_key, present_value})); - using HipT = 
typename ToHipType::MappedType; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; auto workspace_bytes = AttentionTunableOp::GetWorkspaceNumBytes(&attn); auto workspace = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + hipStream_t stream = Stream(context); + if (nullptr != present_key) { // process past present concat + Strides dst_strides; + + int4 past_shape; + Strides past_src_strides; + const HipT* past_key_src; + const HipT* past_value_src; + HipT* past_key_dst{}; + HipT* past_value_dst{}; + + int4 add_shape; + Strides add_src_strides; + const HipT* add_key_src =reinterpret_cast(key->DataRaw()); + const HipT* add_value_src = reinterpret_cast(value->DataRaw()); + HipT* add_key_dst; + HipT* add_value_dst; + + if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH || + attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.total_sequence_length, attn.head_size); + + past_shape = {attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size}; + past_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.past_sequence_length, attn.head_size); + past_key_src = reinterpret_cast(past_key->DataRaw()); + past_value_src = reinterpret_cast(past_value->DataRaw()); + past_key_dst = reinterpret_cast(present_key->MutableDataRaw()); + past_value_dst = reinterpret_cast(present_value->MutableDataRaw()); + + if (attn.mode == BSNH_BLNH_BLNH_BNPH_BNPH_BNTH_BNTH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_BNPH_BNPH_BNTH_BNTH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH || + attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, 
attn.total_sequence_length, attn.head_size); + + if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNTH_BNTH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNTH_BNTH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else if ( + attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || + attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || + attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH || + attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { + dst_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.max_sequence_length, attn.head_size); + + if (attn.mode == BSNH_BLNH_BLNH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BLNH_BLNH_BNMH_BNMH_BNMH_BNMH) { + add_src_strides = Strides::BSNHMemory(attn.batch_size, attn.kv_sequence_length, attn.num_heads, attn.head_size); + } else if (attn.mode == BSNH_BNLH_BNLH_NONE_NONE_BNMH_BNMH || attn.mode == BSNH_BNLH_BNLH_BNMH_BNMH_BNMH_BNMH) { + add_src_strides = Strides::BNSHMemory(attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "past present concatenation is not implemented for attention mode ", attn.mode); + } + add_shape = {attn.batch_size, attn.num_heads, attn.kv_sequence_length, attn.head_size}; // kernel in coord (b,n,s,h) + add_key_dst = reinterpret_cast(present_key->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + add_value_dst = reinterpret_cast(present_value->MutableDataRaw()) + dst_strides.OffsetAt(0, 0, attn.past_sequence_length, 0); + + if (past_key_dst) { + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, past_key_src, past_shape, past_src_strides.ForBNSHCoord(), + past_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + if (past_value_dst) { + 
ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, past_value_src, past_shape, past_src_strides.ForBNSHCoord(), + past_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, add_key_src, add_shape, add_src_strides.ForBNSHCoord(), + add_key_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + stream, add_value_src, add_shape, add_src_strides.ForBNSHCoord(), + add_value_dst, dst_strides.ForBNSHCoord(), device_prop.maxThreadsPerBlock)); + } + GemmSoftmaxGemmPermuteParams params; params.tuning_ctx = GetTuningContext(); - params.stream = Stream(context); + params.stream = stream; params.handle = GetRocblasHandle(context); params.attention = &attn; params.device_prop = &device_prop; params.scale = scale_ == 0 ? 1.0f / sqrt(attn.head_size) : scale_; std::tie(params.q_buffer, params.k_buffer, params.v_buffer) = ConvertToOffsetedBufferViews( &attn, - query->DataRaw(), - key == nullptr ? nullptr : key->DataRaw(), - value == nullptr ? nullptr : value->DataRaw()); + nullptr == query ? nullptr : reinterpret_cast(query->DataRaw()), + nullptr == key ? nullptr : reinterpret_cast(key->DataRaw()), + nullptr == value ? nullptr : reinterpret_cast(value->DataRaw()), + nullptr == present_key ? nullptr : reinterpret_cast(present_key->DataRaw()), + nullptr == present_value ? 
nullptr : reinterpret_cast(present_value->DataRaw())); params.out_buffer = reinterpret_cast(output->MutableDataRaw()); + if (key_padding_mask != nullptr) { + params.mask_index_buffer = key_padding_mask->Data(); + params.mask_index_dims = key_padding_mask->Shape().AsShapeVector(); + } + if (relative_position_bias != nullptr) { params.bias_buffer = reinterpret_cast(relative_position_bias->DataRaw()); } diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index e1deeec424..8686d235df 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -239,6 +239,12 @@ static void RunAttentionTest( tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + if (enable_rocm) { + std::vector> execution_providers; + execution_providers.push_back(DefaultRocmExecutionProvider(/*test_tunable_op=*/true)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (enable_cpu) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 2295873b42..f458031a35 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -9,7 +9,7 @@ #include "test/util/include/scoped_env_vars.h" #include "test/contrib_ops/attention_op_test_helper.h" -#if defined(USE_ROCM) && defined(USE_COMPOSABLE_KERNEL) +#if defined(USE_ROCM) && defined(USE_COMPOSABLE_KERNEL) && !defined(USE_MIGRAPHX) #define DISABLE_ROCM false #else #define DISABLE_ROCM true @@ -21,12 +21,6 @@ #define ROCM_GTEST_SKIP(message) #endif -#if defined(USE_MIGRAPHX) -#define MIGX_GTEST_SKIP(message) GTEST_SKIP_(message) -#else -#define MIGX_GTEST_SKIP(message) -#endif - namespace 
onnxruntime { namespace test { @@ -66,6 +60,16 @@ static void RunMultiHeadAttentionTest( bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + if (enable_rocm && !use_float16) { + LOGS_DEFAULT(WARNING) << "ROCm MHA only have kernel for half datatype implemented, skip float datatype tests"; + enable_rocm = false; + } + + if (enable_rocm && !bias_data.empty()) { + LOGS_DEFAULT(WARNING) << "ROCm MHA does not support qkv_bias, skip qkv_bias tests"; + enable_rocm = false; + } + if (enable_cpu || enable_cuda || enable_rocm || enable_dml) { OpTester tester("MultiHeadAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(num_heads)); @@ -457,7 +461,6 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu // Test fused cross attention kernel // It requires head_size > 32 and head_size <= 64 for T4 GPU; hidden_size == v_hidden_size. 
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { - ROCM_GTEST_SKIP("ROCm MHA does not support bias"); AttentionTestData data; GetCrossAttentionData_HeadSize40(data); RunMultiHeadAttentionTests(data); @@ -467,7 +470,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) { - ROCM_GTEST_SKIP("ROCm MHA does not support mask"); + ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); RunMultiHeadAttentionTests(data, true); @@ -477,7 +480,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { - ROCM_GTEST_SKIP("ROCm MHA does not support mask"); + ROCM_GTEST_SKIP("ROCm MHA expect failure due to ck bug"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); RunMultiHeadAttentionTests(data, true); @@ -487,7 +490,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_M } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { - ROCM_GTEST_SKIP("ROCm MHA does not support mask"); AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); RunMultiHeadAttentionTests(data, true); @@ -497,14 +499,12 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { - MIGX_GTEST_SKIP("MIGX MHA does not support Packed KV"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); RunMultiHeadAttentionTests(data); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { - MIGX_GTEST_SKIP("MIGX MHA does not support Packed QKV"); 
AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); RunMultiHeadAttentionTests(data); @@ -513,7 +513,6 @@ TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_Packe // This tests qk_head_size != v_head_size TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) { - ROCM_GTEST_SKIP("ROCm MHA does not support bias"); AttentionTestData data; GetCrossAttentionData_HeadSize16_8(data); RunMultiHeadAttentionTests(data); @@ -523,7 +522,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) { } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { - ROCM_GTEST_SKIP("ROCm MHA does not support bias"); AttentionTestData data; GetCrossAttentionData_HeadSize16(data); RunMultiHeadAttentionTests(data); @@ -533,21 +531,21 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { } TEST(MultiHeadAttentionTest, CrossAttentionWithPast) { - ROCM_GTEST_SKIP("ROCm MHA does not support attention cache"); + ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetCrossAttentionDataWithPast(data); RunMultiHeadAttentionTests(data); } TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithRelPosBias_ForT5) { - ROCM_GTEST_SKIP("ROCm MHA does not support attention cache"); + ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithRelPosBias_ForT5(data); RunMultiHeadAttentionTests(data, true); } TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) { - ROCM_GTEST_SKIP("ROCm does not support cutlass"); + // ROCM_GTEST_SKIP("ROCm does not support cutlass"); AttentionTestData data; GetAttentionDataCutlassRelPosBias(data); RunMultiHeadAttentionTests(data); @@ -555,7 +553,6 @@ TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) { TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { // Whisper decoder cross attention without mask and different sequence lengths for Q and K/V - 
ROCM_GTEST_SKIP("ROCm not supported"); AttentionTestData data; GetCrossAttentionData_DiffSequenceLengths(data); RunMultiHeadAttentionTests(data); @@ -569,7 +566,6 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBias) { // Whisper decoder self attention with past_kv and present_kv - ROCM_GTEST_SKIP("ROCm not supported"); AttentionTestData data; GetSelfAttentionData_WithPastAndPresent_NoMask_NoRelPosBias(data); RunMultiHeadAttentionTests(data); @@ -583,7 +579,6 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoRelPosBia TEST(MultiHeadAttentionTest, CrossAttention_WithPastPassedInDirectly_NoMask) { // Whisper decoder cross attention with past_kv in place of current KV and no present_kv - ROCM_GTEST_SKIP("ROCm not supported"); AttentionTestData data; GetCrossAttentionData_WithPastPassedInDirectly_NoMask(data); RunMultiHeadAttentionTests(data);