mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Refactor Attention cuda kernel (#17578)
* Break QkvToContext into small functions. Each fused and unfused kernel will have separated function. * Move DecoderAttention kernel to separated file * Move KV cache related kernel to attention_kv_cache.cu ### Motivation and Context To make the code easier to maintain.
This commit is contained in:
parent
068300d97e
commit
730fab3050
12 changed files with 934 additions and 758 deletions
|
|
@ -11,6 +11,8 @@ set(contrib_ops_excluded_files
|
|||
"bert/attention_softmax.h"
|
||||
"bert/attention_softmax.cu"
|
||||
"bert/attention_prepare_qkv.cu"
|
||||
"bert/decoder_attention_impl.h"
|
||||
"bert/decoder_attention_impl.cu"
|
||||
"bert/decoder_masked_multihead_attention.h"
|
||||
"bert/decoder_masked_multihead_attention.cc"
|
||||
"bert/decoder_masked_self_attention.h"
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -81,24 +81,20 @@ struct AttentionData {
|
|||
|
||||
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr;
|
||||
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr;
|
||||
};
|
||||
|
||||
// Intermediate data pointers available after PrepareQKV
|
||||
template <typename T>
|
||||
struct QkvData {
|
||||
// Intermediate data
|
||||
T* q = nullptr;
|
||||
T* k = nullptr;
|
||||
T* v = nullptr;
|
||||
T* after_v = nullptr; // pointer right after v
|
||||
AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
T* scratch = nullptr;
|
||||
AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status PrepareQkv(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
QkvData<T>& qkv_data);
|
||||
int max_threads_per_block);
|
||||
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
|
|
@ -108,33 +104,6 @@ Status QkvToContext(
|
|||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data);
|
||||
|
||||
Status LaunchDecoderAttentionKernel(
|
||||
const cudaDeviceProp& prop, // Device Properties
|
||||
Stream* stream, // ORT Stream
|
||||
cublasHandle_t& cublas, // Cublas handle
|
||||
const size_t element_size, // Element size of input tensor
|
||||
const int batch_size, // Batch size (B)
|
||||
const int sequence_length, // Sequence length (S)
|
||||
const int kv_sequence_length, // Key/Value/Cache sequence length
|
||||
const int num_heads, // Number of attention heads (N)
|
||||
const int head_size, // Hidden size per head (H)
|
||||
const bool static_kv, // Whether cross attention or not
|
||||
const bool use_past, // Whether use cache or not
|
||||
const bool has_layer_state, // Whether output cache or not
|
||||
const bool has_key_padding_mask, // Whether use key_padding_mask or not
|
||||
const float mask_filter_value, // Mask filter value
|
||||
const void* gemm_query_buffer, // Query buffer
|
||||
const void* gemm_kv_buffer, // Key and value buffer
|
||||
const bool* key_padding_mask, // Key padding mask
|
||||
const void* key_cache, // Input key cache
|
||||
const void* value_cache, // Input value cache
|
||||
void* qkv_buffer, // Temporary buffer
|
||||
void* workspace_buffer, // Temporary buffer
|
||||
void* output, // Output tensor
|
||||
void* new_key_cache, // New_key_cache tensor
|
||||
void* new_value_cache // New_value_cache tensor
|
||||
);
|
||||
|
||||
// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
|
||||
Status LaunchTransCtx(cudaStream_t stream,
|
||||
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
|
||||
|
|
@ -184,14 +153,27 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
|
|||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<T>& data,
|
||||
QkvData<T>& qkv);
|
||||
AttentionData<T>& data);
|
||||
|
||||
template <typename T>
|
||||
Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
|
||||
const int max_sequence_length,
|
||||
const int past_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const T* biases,
|
||||
const T* qkv_buffer,
|
||||
T* present);
|
||||
|
||||
template <typename T>
|
||||
Status LaunchStridedCopy(cudaStream_t stream,
|
||||
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
|
||||
T* out, longlong4 out_strides, // coord (b,n,s,h)
|
||||
int max_threads_per_block);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
|
|
@ -244,48 +245,48 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
|
|||
present);
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM // exclude from hipify
|
||||
#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP
|
||||
|
||||
template <typename T>
|
||||
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
|
||||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<T>& data,
|
||||
QkvData<T>& qkv) {
|
||||
AttentionData<T>& data) {
|
||||
// Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
|
||||
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
|
||||
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
|
||||
// When there is past state, the head size for Q/K/V shall be same: H == H_v.
|
||||
|
||||
if (nullptr != data.present) {
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
|
||||
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH ||
|
||||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
|
||||
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatPastToPresent(
|
||||
stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, data.past, qkv.k, data.present));
|
||||
max_threads_per_block, data.past, data.k, data.present));
|
||||
|
||||
// Update pointers to present_k and present_v.
|
||||
qkv.k = data.present;
|
||||
qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
|
||||
}
|
||||
|
||||
if (nullptr != data.past_key || nullptr != data.present_key) {
|
||||
data.k = data.present;
|
||||
data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
|
||||
} else if (nullptr != data.past_key || nullptr != data.present_key) {
|
||||
if (nullptr != data.past_key && nullptr == data.present_key) {
|
||||
qkv.k = const_cast<T*>(data.past_key);
|
||||
qkv.v = const_cast<T*>(data.past_value);
|
||||
data.k = const_cast<T*>(data.past_key);
|
||||
data.v = const_cast<T*>(data.past_value);
|
||||
} else if (nullptr == data.past_key && nullptr != data.present_key) {
|
||||
if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) {
|
||||
qkv.k = data.present_key;
|
||||
qkv.v = data.present_value;
|
||||
if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
|
||||
data.k = data.present_key;
|
||||
data.v = data.present_value;
|
||||
} else {
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
qkv.k = data.temp_k_workspace;
|
||||
qkv.v = data.temp_v_workspace;
|
||||
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
data.k = data.temp_k_workspace;
|
||||
data.v = data.temp_v_workspace;
|
||||
}
|
||||
} else if (pass_past_in_kv) {
|
||||
// past_key and past_value are used directly as key and value in attention computations
|
||||
qkv.k = const_cast<T*>(data.past_key);
|
||||
qkv.v = const_cast<T*>(data.past_value);
|
||||
data.k = const_cast<T*>(data.past_key);
|
||||
data.v = const_cast<T*>(data.past_value);
|
||||
|
||||
// This path has a memory copy from past_key and past_value to present_key and present_value
|
||||
// Avoid this path since the memory copy is unnecessary because past_key == present_key and
|
||||
|
|
@ -298,14 +299,14 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int
|
|||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
|
||||
batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, 1, data.past_key, qkv.k, data.present_key));
|
||||
max_threads_per_block, 1, data.past_key, data.k, data.present_key));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
|
||||
batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, 1, data.past_value, qkv.v, data.present_value));
|
||||
max_threads_per_block, 1, data.past_value, data.v, data.present_value));
|
||||
// Update pointers to present_k and present_v.
|
||||
qkv.k = data.present_key;
|
||||
qkv.v = data.present_value;
|
||||
data.k = data.present_key;
|
||||
data.v = data.present_value;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -317,15 +318,147 @@ template Status ConcatPastToPresent<float>(int batch_size, int num_heads, int qk
|
|||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<float>& data,
|
||||
QkvData<float>& qkv);
|
||||
AttentionData<float>& data);
|
||||
|
||||
template Status ConcatPastToPresent<half>(int batch_size, int num_heads, int qk_head_size, int v_head_size,
|
||||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<half>& data,
|
||||
QkvData<half>& qkv);
|
||||
AttentionData<half>& data);
|
||||
|
||||
// ----------------------------------------------------------------------------------
|
||||
// Below kernels are for past and present sharing buffer
|
||||
// ----------------------------------------------------------------------------------
|
||||
|
||||
template <typename T>
|
||||
__global__ void AddBiasTransAppendKvToPresentSmall(
|
||||
const T* qkv, const T* biases, T* present,
|
||||
const int head_size, const int past_sequence_length, const int max_sequence_length) {
|
||||
// Input: BxSxMxNxH (Format 1)
|
||||
// Output: (2, B, N, [P..P+S) of MaxS, H),
|
||||
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
|
||||
const int n = threadIdx.y;
|
||||
const int s = blockIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
const int N = blockDim.y;
|
||||
const int S = gridDim.x;
|
||||
const int B = gridDim.y;
|
||||
|
||||
constexpr int M = 3; // Matrix count in qkv
|
||||
const int m = blockIdx.z + 1; // k = 1, v = 2
|
||||
|
||||
const int NH = N * head_size;
|
||||
const int NHS = NH * S;
|
||||
|
||||
qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
|
||||
if (biases) {
|
||||
biases += (m * NH + n * head_size);
|
||||
}
|
||||
|
||||
const int MsH = max_sequence_length * head_size;
|
||||
const int NMsH = N * MsH;
|
||||
const int BNMsH = B * NMsH;
|
||||
present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);
|
||||
|
||||
for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
|
||||
T bias = (biases ? biases[h] : (T)0.0f);
|
||||
present[h] = qkv[h] + bias;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void AddBiasTransAppendKvToPresent(
|
||||
const T* qkv, const T* biases, T* present,
|
||||
const int head_size, const int past_sequence_length, const int max_sequence_length) {
|
||||
// Input: BxSxMxNxH (Format 1)
|
||||
// Output: (2, B, N, [P..P+S) of MaxS, H),
|
||||
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
|
||||
const int n = blockIdx.x;
|
||||
const int s = blockIdx.y;
|
||||
const int b = (blockIdx.z >> 1);
|
||||
const int N = gridDim.x;
|
||||
const int S = gridDim.y;
|
||||
const int B = (gridDim.z >> 1);
|
||||
|
||||
constexpr int M = 3; // Matrix count in qkv
|
||||
const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2
|
||||
|
||||
const int NH = N * head_size;
|
||||
const int NHS = NH * S;
|
||||
|
||||
qkv += (n * head_size + (s * M + m) * NH + b * M * NHS);
|
||||
if (biases) {
|
||||
biases += (m * NH + n * head_size);
|
||||
}
|
||||
|
||||
const int MsH = max_sequence_length * head_size;
|
||||
const int NMsH = N * MsH;
|
||||
const int BNMsH = B * NMsH;
|
||||
present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH);
|
||||
|
||||
for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
|
||||
T bias = (biases ? biases[h] : (T)0.0f);
|
||||
present[h] = qkv[h] + bias;
|
||||
}
|
||||
}
|
||||
|
||||
// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3.
|
||||
// bias is of shape (3, NxH) or nullptr
|
||||
// append to present of (2, B, N, (P..T) of M, H),
|
||||
template <typename T>
|
||||
Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
|
||||
const int max_sequence_length,
|
||||
const int past_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const T* biases,
|
||||
const T* qkv_buffer,
|
||||
T* present) {
|
||||
assert(head_size <= (1 << 30));
|
||||
|
||||
int64_t nh = (int64_t)head_size * num_heads;
|
||||
if (nh <= max_threads_per_block) {
|
||||
const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v
|
||||
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
|
||||
|
||||
AddBiasTransAppendKvToPresentSmall<T><<<grid, block, 0, stream>>>(
|
||||
qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
|
||||
} else {
|
||||
const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v
|
||||
const dim3 block(std::min(head_size, max_threads_per_block), 1, 1);
|
||||
AddBiasTransAppendKvToPresent<T><<<grid, block, 0, stream>>>(
|
||||
qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length);
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
|
||||
const int max_sequence_length,
|
||||
const int total_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const float* bias,
|
||||
const float* qkv_buffer,
|
||||
float* present);
|
||||
|
||||
template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
|
||||
const int max_sequence_length,
|
||||
const int total_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const half* bias,
|
||||
const half* qkv_buffer,
|
||||
half* present);
|
||||
#endif
|
||||
|
||||
} // namespace cuda
|
||||
|
|
@ -1,9 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "contrib_ops/cuda/bert/add_bias_transpose.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
|
||||
|
|
@ -406,22 +405,25 @@ Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
|
|||
|
||||
// Query (BxSxNxH) => Q (BxNxSxH)
|
||||
constexpr int format = 0;
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.query, data.bias, q,
|
||||
true, -1);
|
||||
LaunchAddBiasTranspose<T>(
|
||||
stream, 1, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.query, data.bias, q,
|
||||
true, -1);
|
||||
|
||||
// Key (BxLxNxH) => K (BxNxLxH)
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, qk_head_size,
|
||||
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
|
||||
true, -1);
|
||||
LaunchAddBiasTranspose<T>(
|
||||
stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, qk_head_size,
|
||||
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
|
||||
true, -1);
|
||||
|
||||
// Value (BxLxNxH_v) => K (BxNxLxH_v)
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, v_head_size,
|
||||
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
|
||||
true, -1);
|
||||
LaunchAddBiasTranspose<T>(
|
||||
stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, v_head_size,
|
||||
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
|
||||
true, -1);
|
||||
|
||||
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
|
||||
|
|
@ -435,8 +437,8 @@ template <typename T>
|
|||
Status PrepareQkv(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
QkvData<T>& qkv) {
|
||||
int max_threads_per_block) {
|
||||
data.scratch = data.workspace;
|
||||
if (data.has_qkv_workspace) {
|
||||
const int size_per_batch_q = parameters.sequence_length * parameters.head_size;
|
||||
const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size;
|
||||
|
|
@ -445,27 +447,27 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
|
|||
const size_t elements_q = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_q);
|
||||
const size_t elements_k = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_k);
|
||||
const size_t elements_v = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_v);
|
||||
qkv.q = data.workspace;
|
||||
qkv.k = data.workspace + elements_q;
|
||||
qkv.v = qkv.k + elements_k;
|
||||
qkv.after_v = qkv.v + elements_v;
|
||||
data.q = data.workspace;
|
||||
data.k = data.workspace + elements_q;
|
||||
data.v = data.k + elements_k;
|
||||
data.scratch = data.v + elements_v;
|
||||
}
|
||||
|
||||
if (nullptr != data.gemm_buffer) { // Attention operator
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_Attention<T>(parameters, data, stream, max_threads_per_block,
|
||||
qkv.format));
|
||||
data.qkv_format));
|
||||
} else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block,
|
||||
qkv.q, qkv.k, qkv.v, qkv.format));
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
} else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block,
|
||||
qkv.q, qkv.k, qkv.v, qkv.format));
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
} else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block,
|
||||
qkv.q, qkv.k, qkv.v, qkv.format));
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
} else { // multihead attention operator, no past, separated Q/K/V inputs
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block,
|
||||
qkv.q, qkv.k, qkv.v, qkv.format));
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
}
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
|
|
@ -477,15 +479,13 @@ template Status PrepareQkv<float>(
|
|||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<float>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
QkvData<float>& qkv);
|
||||
int max_threads_per_block);
|
||||
|
||||
template Status PrepareQkv<half>(
|
||||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<half>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
QkvData<half>& qkv);
|
||||
int max_threads_per_block);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ limitations under the License.
|
|||
*/
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math_constants.h>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/decoder_attention.h"
|
||||
#include "contrib_ops/cuda/bert/decoder_attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
|
|
@ -85,7 +85,8 @@ Status CheckInputs(const TensorShape& query_shape,
|
|||
}
|
||||
|
||||
if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast<int64_t>(hidden_size)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"kv_weights shall have shape (hidden size, 2 * hidden size)");
|
||||
}
|
||||
|
||||
const auto& bias_dims = bias_shape.GetDims();
|
||||
|
|
@ -137,7 +138,8 @@ Status CheckInputs(const TensorShape& query_shape,
|
|||
|
||||
const auto& value_cache_dims = value_cache->Shape().GetDims();
|
||||
if (value_cache_dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ",
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'value_cache' is expected to have 4 dimension, got ",
|
||||
value_cache_dims.size());
|
||||
}
|
||||
|
||||
|
|
@ -353,10 +355,12 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
size_t bytes = element_size * batch_size * (static_cast<size_t>(sequence_length) + static_cast<size_t>(2) * kv_sequence_length) * hidden_size;
|
||||
size_t bytes = element_size * batch_size *
|
||||
(static_cast<size_t>(sequence_length) + static_cast<size_t>(2) * kv_sequence_length) * hidden_size;
|
||||
auto qkv_buffer_p = GetScratchBuffer<void>(bytes, context->GetComputeStream());
|
||||
|
||||
bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast<size_t>(2) * head_size + static_cast<size_t>(kv_sequence_length));
|
||||
bytes = element_size * 2 * batch_size * sequence_length * num_heads_ *
|
||||
(static_cast<size_t>(2) * head_size + static_cast<size_t>(kv_sequence_length));
|
||||
auto workspace_p = GetScratchBuffer<void>(bytes, context->GetComputeStream());
|
||||
|
||||
Tensor* output(context->Output(0, query_shape));
|
||||
|
|
|
|||
263
onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
Normal file
263
onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/bert/decoder_attention_impl.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "contrib_ops/cuda/bert/attention_softmax.h"
|
||||
|
||||
using namespace onnxruntime::contrib::attention_softmax_cuda;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
Status DecoderQkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
Stream* ort_stream,
|
||||
cublasHandle_t& cublas,
|
||||
const size_t element_size,
|
||||
const int batch_size,
|
||||
const int sequence_length,
|
||||
const int kv_sequence_length,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const bool static_kv,
|
||||
const bool use_past,
|
||||
const bool has_layer_state,
|
||||
const bool has_key_padding_mask,
|
||||
const float mask_filter_value,
|
||||
const T* gemm_query_buffer,
|
||||
const T* gemm_kv_buffer,
|
||||
const bool* key_padding_mask,
|
||||
const T* key_cache,
|
||||
const T* value_cache,
|
||||
T* qkv_buffer,
|
||||
T* workspace_buffer,
|
||||
T* output,
|
||||
T* new_key_cache,
|
||||
T* new_value_cache) {
|
||||
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
|
||||
const int BN = batch_size * num_heads;
|
||||
const int BHN = BN * head_size;
|
||||
const int BNS = BN * sequence_length;
|
||||
const int k_buffer_offset = sequence_length * BHN;
|
||||
const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN;
|
||||
|
||||
T* temp_qkv_buffer = workspace_buffer;
|
||||
auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
|
||||
|
||||
const T* q = qkv_buffer;
|
||||
// transpose q and copy them to qkv_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, gemm_query_buffer, qkv_buffer));
|
||||
|
||||
const T* k = qkv_buffer + k_buffer_offset;
|
||||
const T* v = qkv_buffer + v_buffer_offset;
|
||||
if (!has_layer_state || !use_past) {
|
||||
if (!static_kv) {
|
||||
// transpose kv and copy them to qkv_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
|
||||
} else {
|
||||
// transpose kv and copy them to qkv_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset));
|
||||
}
|
||||
} else {
|
||||
if (!static_kv) {
|
||||
// transpose kv and copy them to temp_buffer
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer));
|
||||
// concat cache-k with k and copy to qkv_buffer
|
||||
if (nullptr != key_cache) {
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length,
|
||||
sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, 1,
|
||||
key_cache,
|
||||
temp_qkv_buffer,
|
||||
qkv_buffer + k_buffer_offset));
|
||||
}
|
||||
// concat cache-v with v and copy to qkv_buffer
|
||||
if (nullptr != value_cache) {
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length,
|
||||
sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, 1,
|
||||
value_cache,
|
||||
temp_qkv_buffer + k_buffer_offset,
|
||||
qkv_buffer + v_buffer_offset));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (has_layer_state) {
|
||||
if (use_past && static_kv) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
} else {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
}
|
||||
|
||||
// scratch1: BxNxSxL buffer
|
||||
// scratch2: BxNxSxL buffer
|
||||
// scratch3: BxNxSxH buffer
|
||||
T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length;
|
||||
T* scratch2 = scratch1 + BNS * kv_sequence_length;
|
||||
T* scratch3 = scratch2 + BNS * kv_sequence_length;
|
||||
|
||||
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL
|
||||
// Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL
|
||||
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
|
||||
const int temp_matrix_size = sequence_length * kv_sequence_length;
|
||||
float one = 1.0f;
|
||||
float zero = 0.f;
|
||||
|
||||
float alpha = rsqrt_head_size;
|
||||
const int strideA = kv_sequence_length * head_size;
|
||||
const int strideB = sequence_length * head_size;
|
||||
if (use_past && static_kv) {
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
kv_sequence_length, sequence_length, head_size,
|
||||
&alpha, key_cache, head_size, strideA,
|
||||
q, head_size, strideB,
|
||||
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
|
||||
} else {
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
kv_sequence_length, sequence_length, head_size,
|
||||
&alpha, k, head_size, strideA,
|
||||
q, head_size, strideB,
|
||||
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
|
||||
}
|
||||
|
||||
constexpr bool is_unidirectional = false;
|
||||
const T* add_before_softmax = nullptr;
|
||||
if (has_key_padding_mask) {
|
||||
constexpr int mask_dimension = 2;
|
||||
constexpr int max_sequence_length = 0;
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask<T>(
|
||||
ort_stream, kv_sequence_length, sequence_length, batch_size,
|
||||
num_heads, nullptr, key_padding_mask, add_before_softmax,
|
||||
false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
|
||||
1.0f, mask_dimension, max_sequence_length, false, nullptr,
|
||||
mask_filter_value));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(
|
||||
stream, kv_sequence_length, sequence_length, batch_size, num_heads,
|
||||
add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2,
|
||||
is_unidirectional));
|
||||
}
|
||||
|
||||
// compute P*V (as V*P), and store in scratch3: BxNxSxH
|
||||
if (use_past && static_kv) {
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
head_size, sequence_length, kv_sequence_length,
|
||||
&one, value_cache, head_size, strideA,
|
||||
scratch2, kv_sequence_length, temp_matrix_size,
|
||||
&zero, scratch3, head_size, strideB, BN, device_prop));
|
||||
} else {
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
head_size, sequence_length, kv_sequence_length,
|
||||
&one, v, head_size, strideA,
|
||||
scratch2, kv_sequence_length, temp_matrix_size,
|
||||
&zero, scratch3, head_size, strideB, BN, device_prop));
|
||||
}
|
||||
|
||||
// scratch3 is BxNxSxH, transpose to output SxBxNxH
|
||||
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, true, scratch3, output);
|
||||
}
|
||||
|
||||
Status LaunchDecoderAttentionKernel(
|
||||
const cudaDeviceProp& device_prop,
|
||||
Stream* stream,
|
||||
cublasHandle_t& cublas,
|
||||
const size_t element_size,
|
||||
const int batch_size,
|
||||
const int sequence_length,
|
||||
const int kv_sequence_length,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const bool static_kv,
|
||||
const bool use_past,
|
||||
const bool has_layer_state,
|
||||
const bool has_key_padding_mask,
|
||||
const float mask_filter_value,
|
||||
const void* gemm_query_buffer,
|
||||
const void* gemm_kv_buffer,
|
||||
const bool* key_padding_mask,
|
||||
const void* key_cache,
|
||||
const void* value_cache,
|
||||
void* qkv_buffer,
|
||||
void* workspace_buffer,
|
||||
void* output,
|
||||
void* new_key_cache,
|
||||
void* new_value_cache) {
|
||||
if (element_size == 2) {
|
||||
return DecoderQkvToContext(
|
||||
device_prop,
|
||||
stream,
|
||||
cublas,
|
||||
element_size,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
kv_sequence_length,
|
||||
num_heads,
|
||||
head_size,
|
||||
static_kv,
|
||||
use_past,
|
||||
has_layer_state,
|
||||
has_key_padding_mask,
|
||||
mask_filter_value,
|
||||
reinterpret_cast<const half*>(gemm_query_buffer),
|
||||
reinterpret_cast<const half*>(gemm_kv_buffer),
|
||||
key_padding_mask,
|
||||
reinterpret_cast<const half*>(key_cache),
|
||||
reinterpret_cast<const half*>(value_cache),
|
||||
reinterpret_cast<half*>(qkv_buffer),
|
||||
reinterpret_cast<half*>(workspace_buffer),
|
||||
reinterpret_cast<half*>(output),
|
||||
reinterpret_cast<half*>(new_key_cache),
|
||||
reinterpret_cast<half*>(new_value_cache));
|
||||
} else {
|
||||
return DecoderQkvToContext(
|
||||
device_prop,
|
||||
stream,
|
||||
cublas,
|
||||
element_size,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
kv_sequence_length,
|
||||
num_heads,
|
||||
head_size,
|
||||
static_kv,
|
||||
use_past,
|
||||
has_layer_state,
|
||||
has_key_padding_mask,
|
||||
mask_filter_value,
|
||||
reinterpret_cast<const float*>(gemm_query_buffer),
|
||||
reinterpret_cast<const float*>(gemm_kv_buffer),
|
||||
key_padding_mask,
|
||||
reinterpret_cast<const float*>(key_cache),
|
||||
reinterpret_cast<const float*>(value_cache),
|
||||
reinterpret_cast<float*>(qkv_buffer),
|
||||
reinterpret_cast<float*>(workspace_buffer),
|
||||
reinterpret_cast<float*>(output),
|
||||
reinterpret_cast<float*>(new_key_cache),
|
||||
reinterpret_cast<float*>(new_value_cache));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
41
onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h
Normal file
41
onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
Status LaunchDecoderAttentionKernel(
|
||||
const cudaDeviceProp& prop, // Device Properties
|
||||
Stream* stream, // ORT Stream
|
||||
cublasHandle_t& cublas, // Cublas handle
|
||||
const size_t element_size, // Element size of input tensor
|
||||
const int batch_size, // Batch size (B)
|
||||
const int sequence_length, // Sequence length (S)
|
||||
const int kv_sequence_length, // Key/Value/Cache sequence length
|
||||
const int num_heads, // Number of attention heads (N)
|
||||
const int head_size, // Hidden size per head (H)
|
||||
const bool static_kv, // Whether cross attention or not
|
||||
const bool use_past, // Whether use cache or not
|
||||
const bool has_layer_state, // Whether output cache or not
|
||||
const bool has_key_padding_mask, // Whether use key_padding_mask or not
|
||||
const float mask_filter_value, // Mask filter value
|
||||
const void* gemm_query_buffer, // Query buffer
|
||||
const void* gemm_kv_buffer, // Key and value buffer
|
||||
const bool* key_padding_mask, // Key padding mask
|
||||
const void* key_cache, // Input key cache
|
||||
const void* value_cache, // Input value cache
|
||||
void* qkv_buffer, // Temporary buffer
|
||||
void* workspace_buffer, // Temporary buffer
|
||||
void* output, // Output tensor
|
||||
void* new_key_cache, // New_key_cache tensor
|
||||
void* new_value_cache // New_value_cache tensor
|
||||
);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "contrib_ops/rocm/bert/attention_impl.h"
|
||||
#include "contrib_ops/rocm/bert/attention_softmax.h"
|
||||
#include "contrib_ops/rocm/bert/decoder_attention_impl.h"
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
|
||||
|
|
|
|||
|
|
@ -28,34 +28,6 @@ size_t GetAttentionWorkspaceSize(
|
|||
int sequence_length,
|
||||
int past_sequence_length);
|
||||
|
||||
Status LaunchDecoderAttentionKernel(
|
||||
const hipDeviceProp_t& prop, // Device Properties
|
||||
RocmTuningContext* tuning_ctx, // context for tuning
|
||||
Stream* stream, // ORT Stream
|
||||
rocblas_handle& rocblas, // Rocblas handle
|
||||
const size_t element_size, // Element size of input tensor
|
||||
const int batch_size, // Batch size (B)
|
||||
const int sequence_length, // Sequence length (S)
|
||||
const int kv_sequence_length, // Key/Value/Cache sequence length
|
||||
const int num_heads, // Number of attention heads (N)
|
||||
const int head_size, // Hidden layer size per head (H)
|
||||
const bool static_kv, // Whether cross attention or not
|
||||
const bool use_past, // Whether use cache or not
|
||||
const bool has_layer_state, // Whether output cache or not
|
||||
const bool has_key_padding_mask, // Whether use key_padding_mask or not
|
||||
const float mask_filter_value, // Mask filter value
|
||||
const void* gemm_query_buffer, // Query buffer
|
||||
const void* gemm_kv_buffer, // Key and value buffer
|
||||
const bool* key_padding_mask, // Key padding mask
|
||||
const void* key_cache, // Input key cache
|
||||
const void* value_cache, // Input value cache
|
||||
void* qkv_buffer, // Temporary buffer
|
||||
void* workspace_buffer, // Temporary buffer
|
||||
void* output, // Output tensor
|
||||
void* new_key_cache, // New_key_cache tensor
|
||||
void* new_value_cache // New_value_cache tensor
|
||||
);
|
||||
|
||||
Status LaunchTransCtx(hipStream_t stream,
|
||||
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
|
||||
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);
|
||||
|
|
|
|||
46
onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h
Normal file
46
onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "core/providers/rocm/shared_inc/rocm_utils.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
Status LaunchDecoderAttentionKernel(
|
||||
const hipDeviceProp_t& prop, // Device Properties
|
||||
RocmTuningContext* tuning_ctx, // context for tuning
|
||||
Stream* stream, // ORT Stream
|
||||
rocblas_handle& rocblas, // Rocblas handle
|
||||
const size_t element_size, // Element size of input tensor
|
||||
const int batch_size, // Batch size (B)
|
||||
const int sequence_length, // Sequence length (S)
|
||||
const int kv_sequence_length, // Key/Value/Cache sequence length
|
||||
const int num_heads, // Number of attention heads (N)
|
||||
const int head_size, // Hidden layer size per head (H)
|
||||
const bool static_kv, // Whether cross attention or not
|
||||
const bool use_past, // Whether use cache or not
|
||||
const bool has_layer_state, // Whether output cache or not
|
||||
const bool has_key_padding_mask, // Whether use key_padding_mask or not
|
||||
const float mask_filter_value, // Mask filter value
|
||||
const void* gemm_query_buffer, // Query buffer
|
||||
const void* gemm_kv_buffer, // Key and value buffer
|
||||
const bool* key_padding_mask, // Key padding mask
|
||||
const void* key_cache, // Input key cache
|
||||
const void* value_cache, // Input value cache
|
||||
void* qkv_buffer, // Temporary buffer
|
||||
void* workspace_buffer, // Temporary buffer
|
||||
void* output, // Output tensor
|
||||
void* new_key_cache, // New_key_cache tensor
|
||||
void* new_value_cache // New_value_cache tensor
|
||||
);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue