Refactor Attention cuda kernel (#17578)

* Break QkvToContext into small functions. Each fused and unfused kernel
will have separated function.
* Move DecoderAttention kernel to separated file
* Move KV cache related kernel to attention_kv_cache.cu

### Motivation and Context
To make the code easier to maintain.
This commit is contained in:
Tianlei Wu 2023-09-19 09:49:21 -07:00 committed by GitHub
parent 068300d97e
commit 730fab3050
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 934 additions and 758 deletions

View file

@ -11,6 +11,8 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/decoder_attention_impl.h"
"bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"
"bert/decoder_masked_multihead_attention.cc"
"bert/decoder_masked_self_attention.h"

File diff suppressed because it is too large Load diff

View file

@ -81,24 +81,20 @@ struct AttentionData {
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr;
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr;
};
// Intermediate data pointers available after PrepareQKV
template <typename T>
struct QkvData {
// Intermediate data
T* q = nullptr;
T* k = nullptr;
T* v = nullptr;
T* after_v = nullptr; // pointer right after v
AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH;
T* scratch = nullptr;
AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
};
template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<T>& qkv_data);
int max_threads_per_block);
template <typename T>
Status QkvToContext(
@ -108,33 +104,6 @@ Status QkvToContext(
contrib::AttentionParameters& parameters,
AttentionData<T>& data);
Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
Stream* stream, // ORT Stream
cublasHandle_t& cublas, // Cublas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
const int kv_sequence_length, // Key/Value/Cache sequence length
const int num_heads, // Number of attention heads (N)
const int head_size, // Hidden size per head (H)
const bool static_kv, // Whether cross attention or not
const bool use_past, // Whether use cache or not
const bool has_layer_state, // Whether output cache or not
const bool has_key_padding_mask, // Whether use key_padding_mask or not
const float mask_filter_value, // Mask filter value
const void* gemm_query_buffer, // Query buffer
const void* gemm_kv_buffer, // Key and value buffer
const bool* key_padding_mask, // Key padding mask
const void* key_cache, // Input key cache
const void* value_cache, // Input value cache
void* qkv_buffer, // Temporary buffer
void* workspace_buffer, // Temporary buffer
void* output, // Output tensor
void* new_key_cache, // New_key_cache tensor
void* new_value_cache // New_value_cache tensor
);
// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
Status LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
@ -184,14 +153,27 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<T>& data,
QkvData<T>& qkv);
AttentionData<T>& data);
template <typename T>
Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int past_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const T* biases,
const T* qkv_buffer,
T* present);
template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"
using namespace onnxruntime::cuda;
@ -244,48 +245,48 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
present);
}
#ifndef USE_ROCM // exclude from hipify
#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP
template <typename T>
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<T>& data,
QkvData<T>& qkv) {
AttentionData<T>& data) {
// Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
// When there is past state, the head size for Q/K/V shall be same: H == H_v.
if (nullptr != data.present) {
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
ORT_RETURN_IF_ERROR(
LaunchConcatPastToPresent(
stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, data.past, qkv.k, data.present));
max_threads_per_block, data.past, data.k, data.present));
// Update pointers to present_k and present_v.
qkv.k = data.present;
qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
}
if (nullptr != data.past_key || nullptr != data.present_key) {
data.k = data.present;
data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
} else if (nullptr != data.past_key || nullptr != data.present_key) {
if (nullptr != data.past_key && nullptr == data.present_key) {
qkv.k = const_cast<T*>(data.past_key);
qkv.v = const_cast<T*>(data.past_value);
data.k = const_cast<T*>(data.past_key);
data.v = const_cast<T*>(data.past_value);
} else if (nullptr == data.past_key && nullptr != data.present_key) {
if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) {
qkv.k = data.present_key;
qkv.v = data.present_value;
if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
data.k = data.present_key;
data.v = data.present_value;
} else {
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
qkv.k = data.temp_k_workspace;
qkv.v = data.temp_v_workspace;
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
data.k = data.temp_k_workspace;
data.v = data.temp_v_workspace;
}
} else if (pass_past_in_kv) {
// past_key and past_value are used directly as key and value in attention computations
qkv.k = const_cast<T*>(data.past_key);
qkv.v = const_cast<T*>(data.past_value);
data.k = const_cast<T*>(data.past_key);
data.v = const_cast<T*>(data.past_value);
// This path has a memory copy from past_key and past_value to present_key and present_value
// Avoid this path since the memory copy is unnecessary because past_key == present_key and
@ -298,14 +299,14 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
batch_size, qk_head_size, num_heads,
max_threads_per_block, 1, data.past_key, qkv.k, data.present_key));
max_threads_per_block, 1, data.past_key, data.k, data.present_key));
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
batch_size, v_head_size, num_heads,
max_threads_per_block, 1, data.past_value, qkv.v, data.present_value));
max_threads_per_block, 1, data.past_value, data.v, data.present_value));
// Update pointers to present_k and present_v.
qkv.k = data.present_key;
qkv.v = data.present_value;
data.k = data.present_key;
data.v = data.present_value;
}
}
@ -317,15 +318,147 @@ template Status ConcatPastToPresent<float>(int batch_size, int num_heads, int qk
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<float>& data,
QkvData<float>& qkv);
AttentionData<float>& data);
template Status ConcatPastToPresent<half>(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<half>& data,
QkvData<half>& qkv);
AttentionData<half>& data);
// ----------------------------------------------------------------------------------
// Below kernels are for past and present sharing buffer
// ----------------------------------------------------------------------------------
template <typename T>
__global__ void AddBiasTransAppendKvToPresentSmall(
const T* qkv, const T* biases, T* present,
const int head_size, const int past_sequence_length, const int max_sequence_length) {
// Input: BxSxMxNxH (Format 1)
// Output: (2, B, N, [P..P+S) of MaxS, H),
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int N = blockDim.y;
const int S = gridDim.x;
const int B = gridDim.y;
constexpr int M = 3; // Matrix count in qkv
const int m = blockIdx.z + 1; // k = 1, v = 2
const int NH = N * head_size;
const int NHS = NH * S;
qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
if (biases) {
biases += (m * NH + n * head_size);
}
const int MsH = max_sequence_length * head_size;
const int NMsH = N * MsH;
const int BNMsH = B * NMsH;
present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);
for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
T bias = (biases ? biases[h] : (T)0.0f);
present[h] = qkv[h] + bias;
}
}
template <typename T>
__global__ void AddBiasTransAppendKvToPresent(
const T* qkv, const T* biases, T* present,
const int head_size, const int past_sequence_length, const int max_sequence_length) {
// Input: BxSxMxNxH (Format 1)
// Output: (2, B, N, [P..P+S) of MaxS, H),
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
const int n = blockIdx.x;
const int s = blockIdx.y;
const int b = (blockIdx.z >> 1);
const int N = gridDim.x;
const int S = gridDim.y;
const int B = (gridDim.z >> 1);
constexpr int M = 3; // Matrix count in qkv
const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2
const int NH = N * head_size;
const int NHS = NH * S;
qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
if (biases) {
biases += (m * NH + n * head_size);
}
const int MsH = max_sequence_length * head_size;
const int NMsH = N * MsH;
const int BNMsH = B * NMsH;
present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);
for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
T bias = (biases ? biases[h] : (T)0.0f);
present[h] = qkv[h] + bias;
}
}
// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3.
// bias is of shape (3, NxH) or nullptr
// append to present of (2, B, N, (P..T) of M, H),
template <typename T>
Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int past_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const T* biases,
const T* qkv_buffer,
T* present) {
assert(head_size <= (1 << 30));
int64_t nh = (int64_t)head_size * num_heads;
if (nh <= max_threads_per_block) {
const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
AddBiasTransAppendKvToPresentSmall<T><<<grid, block, 0, stream>>>(
qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
} else {
const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v
const dim3 block(std::min(head_size, max_threads_per_block), 1, 1);
AddBiasTransAppendKvToPresent<T><<<grid, block, 0, stream>>>(
qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
}
return CUDA_CALL(cudaGetLastError());
}
template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int total_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* bias,
const float* qkv_buffer,
float* present);
template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const int max_sequence_length,
const int total_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* bias,
const half* qkv_buffer,
half* present);
#endif
} // namespace cuda

View file

@ -1,9 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cuda_fp16.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/bert/add_bias_transpose.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
@ -406,22 +405,25 @@ Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
// Query (BxSxNxH) => Q (BxNxSxH)
constexpr int format = 0;
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, q,
true, -1);
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, q,
true, -1);
// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
true, -1);
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
true, -1);
// Value (BxLxNxH_v) => K (BxNxLxH_v)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
true, -1);
LaunchAddBiasTranspose<T>(
stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
true, -1);
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
@ -435,8 +437,8 @@ template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<T>& qkv) {
int max_threads_per_block) {
data.scratch = data.workspace;
if (data.has_qkv_workspace) {
const int size_per_batch_q = parameters.sequence_length * parameters.head_size;
const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size;
@ -445,27 +447,27 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
const size_t elements_q = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_q);
const size_t elements_k = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_k);
const size_t elements_v = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_v);
qkv.q = data.workspace;
qkv.k = data.workspace + elements_q;
qkv.v = qkv.k + elements_k;
qkv.after_v = qkv.v + elements_v;
data.q = data.workspace;
data.k = data.workspace + elements_q;
data.v = data.k + elements_k;
data.scratch = data.v + elements_v;
}
if (nullptr != data.gemm_buffer) { // Attention operator
ORT_RETURN_IF_ERROR(PrepareQkv_Attention<T>(parameters, data, stream, max_threads_per_block,
qkv.format));
data.qkv_format));
} else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
data.q, data.k, data.v, data.qkv_format));
} else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
data.q, data.k, data.v, data.qkv_format));
} else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
data.q, data.k, data.v, data.qkv_format));
} else { // multihead attention operator, no past, separated Q/K/V inputs
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
data.q, data.k, data.v, data.qkv_format));
}
CUDA_RETURN_IF_ERROR(cudaGetLastError());
@ -477,15 +479,13 @@ template Status PrepareQkv<float>(
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<float>& qkv);
int max_threads_per_block);
template Status PrepareQkv<half>(
contrib::AttentionParameters& parameters,
AttentionData<half>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<half>& qkv);
int max_threads_per_block);
} // namespace cuda
} // namespace contrib

View file

@ -18,7 +18,6 @@ limitations under the License.
*/
#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"

View file

@ -1,8 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/decoder_attention.h"
#include "contrib_ops/cuda/bert/decoder_attention_impl.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
@ -85,7 +85,8 @@ Status CheckInputs(const TensorShape& query_shape,
}
if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast<int64_t>(hidden_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"kv_weights shall have shape (hidden size, 2 * hidden size)");
}
const auto& bias_dims = bias_shape.GetDims();
@ -137,7 +138,8 @@ Status CheckInputs(const TensorShape& query_shape,
const auto& value_cache_dims = value_cache->Shape().GetDims();
if (value_cache_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ",
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'value_cache' is expected to have 4 dimension, got ",
value_cache_dims.size());
}
@ -353,10 +355,12 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
}
size_t bytes = element_size * batch_size * (static_cast<size_t>(sequence_length) + static_cast<size_t>(2) * kv_sequence_length) * hidden_size;
size_t bytes = element_size * batch_size *
(static_cast<size_t>(sequence_length) + static_cast<size_t>(2) * kv_sequence_length) * hidden_size;
auto qkv_buffer_p = GetScratchBuffer<void>(bytes, context->GetComputeStream());
bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast<size_t>(2) * head_size + static_cast<size_t>(kv_sequence_length));
bytes = element_size * 2 * batch_size * sequence_length * num_heads_ *
(static_cast<size_t>(2) * head_size + static_cast<size_t>(kv_sequence_length));
auto workspace_p = GetScratchBuffer<void>(bytes, context->GetComputeStream());
Tensor* output(context->Output(0, query_shape));

View file

@ -0,0 +1,263 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/bert/decoder_attention_impl.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "contrib_ops/cuda/bert/attention_softmax.h"
using namespace onnxruntime::contrib::attention_softmax_cuda;
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T>
Status DecoderQkvToContext(
const cudaDeviceProp& device_prop,
Stream* ort_stream,
cublasHandle_t& cublas,
const size_t element_size,
const int batch_size,
const int sequence_length,
const int kv_sequence_length,
const int num_heads,
const int head_size,
const bool static_kv,
const bool use_past,
const bool has_layer_state,
const bool has_key_padding_mask,
const float mask_filter_value,
const T* gemm_query_buffer,
const T* gemm_kv_buffer,
const bool* key_padding_mask,
const T* key_cache,
const T* value_cache,
T* qkv_buffer,
T* workspace_buffer,
T* output,
T* new_key_cache,
T* new_value_cache) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int BN = batch_size * num_heads;
const int BHN = BN * head_size;
const int BNS = BN * sequence_length;
const int k_buffer_offset = sequence_length * BHN;
const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN;
T* temp_qkv_buffer = workspace_buffer;
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const T* q = qkv_buffer;
// transpose q and copy them to qkv_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, gemm_query_buffer, qkv_buffer));
const T* k = qkv_buffer + k_buffer_offset;
const T* v = qkv_buffer + v_buffer_offset;
if (!has_layer_state || !use_past) {
if (!static_kv) {
// transpose kv and copy them to qkv_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
} else {
// transpose kv and copy them to qkv_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
}
} else {
if (!static_kv) {
// transpose kv and copy them to temp_buffer
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer));
// concat cache-k with k and copy to qkv_buffer
if (nullptr != key_cache) {
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length,
sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, 1,
key_cache,
temp_qkv_buffer,
qkv_buffer + k_buffer_offset));
}
// concat cache-v with v and copy to qkv_buffer
if (nullptr != value_cache) {
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length,
sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, 1,
value_cache,
temp_qkv_buffer + k_buffer_offset,
qkv_buffer + v_buffer_offset));
}
}
}
if (has_layer_state) {
if (use_past && static_kv) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T),
cudaMemcpyDeviceToDevice, stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T),
cudaMemcpyDeviceToDevice, stream));
} else {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T),
cudaMemcpyDeviceToDevice, stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T),
cudaMemcpyDeviceToDevice, stream));
}
}
// scratch1: BxNxSxL buffer
// scratch2: BxNxSxL buffer
// scratch3: BxNxSxH buffer
T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length;
T* scratch2 = scratch1 + BNS * kv_sequence_length;
T* scratch3 = scratch2 + BNS * kv_sequence_length;
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL
// Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
const int temp_matrix_size = sequence_length * kv_sequence_length;
float one = 1.0f;
float zero = 0.f;
float alpha = rsqrt_head_size;
const int strideA = kv_sequence_length * head_size;
const int strideB = sequence_length * head_size;
if (use_past && static_kv) {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
kv_sequence_length, sequence_length, head_size,
&alpha, key_cache, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
kv_sequence_length, sequence_length, head_size,
&alpha, k, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
}
constexpr bool is_unidirectional = false;
const T* add_before_softmax = nullptr;
if (has_key_padding_mask) {
constexpr int mask_dimension = 2;
constexpr int max_sequence_length = 0;
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask<T>(
ort_stream, kv_sequence_length, sequence_length, batch_size,
num_heads, nullptr, key_padding_mask, add_before_softmax,
false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
1.0f, mask_dimension, max_sequence_length, false, nullptr,
mask_filter_value));
} else {
ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(
stream, kv_sequence_length, sequence_length, batch_size, num_heads,
add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2,
is_unidirectional));
}
// compute P*V (as V*P), and store in scratch3: BxNxSxH
if (use_past && static_kv) {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
head_size, sequence_length, kv_sequence_length,
&one, value_cache, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
head_size, sequence_length, kv_sequence_length,
&one, v, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
}
// scratch3 is BxNxSxH, transpose to output SxBxNxH
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, true, scratch3, output);
}
Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& device_prop,
Stream* stream,
cublasHandle_t& cublas,
const size_t element_size,
const int batch_size,
const int sequence_length,
const int kv_sequence_length,
const int num_heads,
const int head_size,
const bool static_kv,
const bool use_past,
const bool has_layer_state,
const bool has_key_padding_mask,
const float mask_filter_value,
const void* gemm_query_buffer,
const void* gemm_kv_buffer,
const bool* key_padding_mask,
const void* key_cache,
const void* value_cache,
void* qkv_buffer,
void* workspace_buffer,
void* output,
void* new_key_cache,
void* new_value_cache) {
if (element_size == 2) {
return DecoderQkvToContext(
device_prop,
stream,
cublas,
element_size,
batch_size,
sequence_length,
kv_sequence_length,
num_heads,
head_size,
static_kv,
use_past,
has_layer_state,
has_key_padding_mask,
mask_filter_value,
reinterpret_cast<const half*>(gemm_query_buffer),
reinterpret_cast<const half*>(gemm_kv_buffer),
key_padding_mask,
reinterpret_cast<const half*>(key_cache),
reinterpret_cast<const half*>(value_cache),
reinterpret_cast<half*>(qkv_buffer),
reinterpret_cast<half*>(workspace_buffer),
reinterpret_cast<half*>(output),
reinterpret_cast<half*>(new_key_cache),
reinterpret_cast<half*>(new_value_cache));
} else {
return DecoderQkvToContext(
device_prop,
stream,
cublas,
element_size,
batch_size,
sequence_length,
kv_sequence_length,
num_heads,
head_size,
static_kv,
use_past,
has_layer_state,
has_key_padding_mask,
mask_filter_value,
reinterpret_cast<const float*>(gemm_query_buffer),
reinterpret_cast<const float*>(gemm_kv_buffer),
key_padding_mask,
reinterpret_cast<const float*>(key_cache),
reinterpret_cast<const float*>(value_cache),
reinterpret_cast<float*>(qkv_buffer),
reinterpret_cast<float*>(workspace_buffer),
reinterpret_cast<float*>(output),
reinterpret_cast<float*>(new_key_cache),
reinterpret_cast<float*>(new_value_cache));
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cuda/bert/attention_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
Stream* stream, // ORT Stream
cublasHandle_t& cublas, // Cublas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
const int kv_sequence_length, // Key/Value/Cache sequence length
const int num_heads, // Number of attention heads (N)
const int head_size, // Hidden size per head (H)
const bool static_kv, // Whether cross attention or not
const bool use_past, // Whether use cache or not
const bool has_layer_state, // Whether output cache or not
const bool has_key_padding_mask, // Whether use key_padding_mask or not
const float mask_filter_value, // Mask filter value
const void* gemm_query_buffer, // Query buffer
const void* gemm_kv_buffer, // Key and value buffer
const bool* key_padding_mask, // Key padding mask
const void* key_cache, // Input key cache
const void* value_cache, // Input value cache
void* qkv_buffer, // Temporary buffer
void* workspace_buffer, // Temporary buffer
void* output, // Output tensor
void* new_key_cache, // New_key_cache tensor
void* new_value_cache // New_value_cache tensor
);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -30,6 +30,7 @@ limitations under the License.
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/attention_softmax.h"
#include "contrib_ops/rocm/bert/decoder_attention_impl.h"
using namespace onnxruntime::rocm;

View file

@ -28,34 +28,6 @@ size_t GetAttentionWorkspaceSize(
int sequence_length,
int past_sequence_length);
Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
RocmTuningContext* tuning_ctx, // context for tuning
Stream* stream, // ORT Stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
const int kv_sequence_length, // Key/Value/Cache sequence length
const int num_heads, // Number of attention heads (N)
const int head_size, // Hidden layer size per head (H)
const bool static_kv, // Whether cross attention or not
const bool use_past, // Whether use cache or not
const bool has_layer_state, // Whether output cache or not
const bool has_key_padding_mask, // Whether use key_padding_mask or not
const float mask_filter_value, // Mask filter value
const void* gemm_query_buffer, // Query buffer
const void* gemm_kv_buffer, // Key and value buffer
const bool* key_padding_mask, // Key padding mask
const void* key_cache, // Input key cache
const void* value_cache, // Input value cache
void* qkv_buffer, // Temporary buffer
void* workspace_buffer, // Temporary buffer
void* output, // Output tensor
void* new_key_cache, // New_key_cache tensor
void* new_value_cache // New_value_cache tensor
);
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);

View file

@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
RocmTuningContext* tuning_ctx, // context for tuning
Stream* stream, // ORT Stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
const int kv_sequence_length, // Key/Value/Cache sequence length
const int num_heads, // Number of attention heads (N)
const int head_size, // Hidden layer size per head (H)
const bool static_kv, // Whether cross attention or not
const bool use_past, // Whether use cache or not
const bool has_layer_state, // Whether output cache or not
const bool has_key_padding_mask, // Whether use key_padding_mask or not
const float mask_filter_value, // Mask filter value
const void* gemm_query_buffer, // Query buffer
const void* gemm_kv_buffer, // Key and value buffer
const bool* key_padding_mask, // Key padding mask
const void* key_cache, // Input key cache
const void* value_cache, // Input value cache
void* qkv_buffer, // Temporary buffer
void* workspace_buffer, // Temporary buffer
void* output, // Output tensor
void* new_key_cache, // New_key_cache tensor
void* new_value_cache // New_value_cache tensor
);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime