mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Refactoring of attention cuda kernel: move prepare qkv and concat_past_to_present (#17559)
To avoid a huge .cu file and make the code more readable: (1) move PrepareQKV to a separate .cu file (attention_prepare_qkv.cu); (2) move ConcatPastToPresent to attention_concat.cu; (3) add default values for AttentionData; (4) add a data structure, QkvData, to track the Q, K and V pointers and the QKV format.
This commit is contained in:
parent
af80542e65
commit
adb0be45d3
8 changed files with 729 additions and 674 deletions
|
|
@ -10,6 +10,7 @@ set(contrib_ops_excluded_files
|
|||
"bert/attention_impl.cu"
|
||||
"bert/attention_softmax.h"
|
||||
"bert/attention_softmax.cu"
|
||||
"bert/attention_prepare_qkv.cu"
|
||||
"bert/decoder_masked_multihead_attention.h"
|
||||
"bert/decoder_masked_multihead_attention.cc"
|
||||
"bert/decoder_masked_self_attention.h"
|
||||
|
|
|
|||
|
|
@ -249,30 +249,28 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
AttentionData<CudaT> data;
|
||||
data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
|
||||
data.bias = nullptr == bias ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
|
||||
data.query = nullptr;
|
||||
data.key = nullptr;
|
||||
data.value = nullptr;
|
||||
data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data<int>();
|
||||
data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
|
||||
data.past = (nullptr == past) ? nullptr : reinterpret_cast<const CudaT*>(past->Data<T>());
|
||||
data.past_key = nullptr;
|
||||
data.past_value = nullptr;
|
||||
data.relative_position_bias = (nullptr == relative_position_bias)
|
||||
? nullptr
|
||||
: reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
|
||||
if (nullptr != bias) {
|
||||
data.bias = reinterpret_cast<const CudaT*>(bias->Data<T>());
|
||||
}
|
||||
if (nullptr != mask_index) {
|
||||
data.mask_index = mask_index->Data<int>();
|
||||
data.mask_index_dims = mask_index->Shape().GetDims();
|
||||
}
|
||||
if (nullptr != past) {
|
||||
data.past = reinterpret_cast<const CudaT*>(past->Data<T>());
|
||||
}
|
||||
if (nullptr != relative_position_bias) {
|
||||
data.relative_position_bias = reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
|
||||
}
|
||||
data.has_qkv_workspace = true;
|
||||
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
|
||||
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
|
||||
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
|
||||
data.present_key = nullptr;
|
||||
data.present_value = nullptr;
|
||||
if (nullptr != present) {
|
||||
data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
|
||||
}
|
||||
data.fused_runner = reinterpret_cast<void*>(fused_runner);
|
||||
data.fused_cross_attention_kernel = nullptr;
|
||||
data.use_flash_attention = use_flash_attention;
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
data.cumulated_sequence_length_q_cache = nullptr;
|
||||
data.cumulated_sequence_length_kv_cache = nullptr;
|
||||
|
||||
return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,16 +93,16 @@ __global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length,
|
|||
}
|
||||
|
||||
Status LaunchConcatTensorToTensor(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const int matrix_num,
|
||||
const float* tensor_in,
|
||||
const float* tensor_add,
|
||||
float* tensor_out) {
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const int matrix_num,
|
||||
const float* tensor_in,
|
||||
const float* tensor_add,
|
||||
float* tensor_out) {
|
||||
const dim3 grid(all_sequence_length, batch_size, matrix_num);
|
||||
if (0 == (head_size & 1)) {
|
||||
const int H = head_size / 2;
|
||||
|
|
@ -137,16 +137,16 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
|
|||
}
|
||||
|
||||
Status LaunchConcatTensorToTensor(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const int matrix_num,
|
||||
const half* tensor_in,
|
||||
const half* tensor_add,
|
||||
half* tensor_out) {
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const int matrix_num,
|
||||
const half* tensor_in,
|
||||
const half* tensor_add,
|
||||
half* tensor_out) {
|
||||
const dim3 grid(all_sequence_length, batch_size, matrix_num);
|
||||
if (0 == (head_size % 4)) {
|
||||
const int H = head_size / 4;
|
||||
|
|
@ -197,15 +197,15 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
|
|||
}
|
||||
|
||||
Status LaunchConcatPastToPresent(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const float* past,
|
||||
const float* k_v,
|
||||
float* present) {
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const float* past,
|
||||
const float* k_v,
|
||||
float* present) {
|
||||
return LaunchConcatTensorToTensor(
|
||||
stream,
|
||||
all_sequence_length,
|
||||
|
|
@ -221,15 +221,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
|
|||
}
|
||||
|
||||
Status LaunchConcatPastToPresent(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const half* past,
|
||||
const half* k_v,
|
||||
half* present) {
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const half* past,
|
||||
const half* k_v,
|
||||
half* present) {
|
||||
return LaunchConcatTensorToTensor(
|
||||
stream,
|
||||
all_sequence_length,
|
||||
|
|
@ -244,6 +244,90 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
|
|||
present);
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM // exclude from hipify
|
||||
template <typename T>
|
||||
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
|
||||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<T>& data,
|
||||
QkvData<T>& qkv) {
|
||||
// Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
|
||||
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
|
||||
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
|
||||
// When there is past state, the head size for Q/K/V shall be same: H == H_v.
|
||||
|
||||
if (nullptr != data.present) {
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatPastToPresent(
|
||||
stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, data.past, qkv.k, data.present));
|
||||
|
||||
// Update pointers to present_k and present_v.
|
||||
qkv.k = data.present;
|
||||
qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
|
||||
}
|
||||
|
||||
if (nullptr != data.past_key || nullptr != data.present_key) {
|
||||
if (nullptr != data.past_key && nullptr == data.present_key) {
|
||||
qkv.k = const_cast<T*>(data.past_key);
|
||||
qkv.v = const_cast<T*>(data.past_value);
|
||||
} else if (nullptr == data.past_key && nullptr != data.present_key) {
|
||||
if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) {
|
||||
qkv.k = data.present_key;
|
||||
qkv.v = data.present_value;
|
||||
} else {
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
qkv.k = data.temp_k_workspace;
|
||||
qkv.v = data.temp_v_workspace;
|
||||
}
|
||||
} else if (pass_past_in_kv) {
|
||||
// past_key and past_value are used directly as key and value in attention computations
|
||||
qkv.k = const_cast<T*>(data.past_key);
|
||||
qkv.v = const_cast<T*>(data.past_value);
|
||||
|
||||
// This path has a memory copy from past_key and past_value to present_key and present_value
|
||||
// Avoid this path since the memory copy is unnecessary because past_key == present_key and
|
||||
// past_value == present_value
|
||||
int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size;
|
||||
int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size;
|
||||
cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
|
||||
cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
|
||||
batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, 1, data.past_key, qkv.k, data.present_key));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
|
||||
batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, 1, data.past_value, qkv.v, data.present_value));
|
||||
// Update pointers to present_k and present_v.
|
||||
qkv.k = data.present_key;
|
||||
qkv.v = data.present_value;
|
||||
}
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
// Template Instantiation
|
||||
template Status ConcatPastToPresent<float>(int batch_size, int num_heads, int qk_head_size, int v_head_size,
|
||||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<float>& data,
|
||||
QkvData<float>& qkv);
|
||||
|
||||
template Status ConcatPastToPresent<half>(int batch_size, int num_heads, int qk_head_size, int v_head_size,
|
||||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<half>& data,
|
||||
QkvData<half>& qkv);
|
||||
#endif
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -26,16 +26,11 @@ limitations under the License.
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cassert>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/attention_softmax.h"
|
||||
#include "contrib_ops/cuda/bert/transformer_common.h"
|
||||
#include "contrib_ops/cuda/bert/add_bias_transpose.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h"
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
|
|
@ -43,6 +38,7 @@ limitations under the License.
|
|||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace onnxruntime::contrib::attention_softmax_cuda;
|
||||
|
|
@ -286,446 +282,6 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
|
|||
const half* qkv_buffer,
|
||||
half* present);
|
||||
|
||||
template <typename T>
|
||||
Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionQkvFormat& qkv_format) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int qk_head_size = parameters.head_size;
|
||||
const int v_head_size = parameters.v_head_size;
|
||||
const bool past_present_share_buffer = parameters.past_present_share_buffer;
|
||||
void* fused_runner = data.fused_runner;
|
||||
bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention;
|
||||
|
||||
T* qkv = data.workspace;
|
||||
|
||||
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
|
||||
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
|
||||
|
||||
if (data.bias == nullptr) {
|
||||
assert(nullptr == fused_runner);
|
||||
// For quantized attention, bias has been added so only need transpose here.
|
||||
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
|
||||
assert(qk_head_size == v_head_size);
|
||||
int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.gemm_buffer, qkv, 3));
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
|
||||
} else {
|
||||
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
|
||||
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
|
||||
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
|
||||
// For fused causal kernel, use format 1 since we need have K and V to update present state,
|
||||
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
|
||||
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
|
||||
qkv_format = use_fused_kernel
|
||||
? AttentionQkvFormat::QKV_BSN3H
|
||||
: (use_flash_or_efficient_attention
|
||||
? AttentionQkvFormat::Q_K_V_BSNH
|
||||
: (use_fused_causal
|
||||
? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
|
||||
: AttentionQkvFormat::Q_K_V_BNSH));
|
||||
|
||||
// For fused causal, we will update gemm_buffer with bias directly.
|
||||
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
|
||||
|
||||
int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
|
||||
// format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
|
||||
// format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
|
||||
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
|
||||
3, parameters.do_rotary, parameters.past_sequence_length);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For MultiHeadAttention with past state
|
||||
template <typename T>
|
||||
Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int kv_sequence_length = parameters.kv_sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int qk_head_size = parameters.head_size;
|
||||
const int v_head_size = parameters.v_head_size;
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
|
||||
if (data.bias == nullptr) {
|
||||
// Below logic does not support fused attention with past without bias
|
||||
// When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past.
|
||||
|
||||
// cross attention with past state
|
||||
if (data.past_key != nullptr && data.present_key == nullptr) {
|
||||
assert(data.past_value != nullptr);
|
||||
assert(data.query != nullptr);
|
||||
assert(data.key == nullptr);
|
||||
assert(data.value == nullptr);
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.query, q));
|
||||
}
|
||||
// cross attention with present state or self attention with present state
|
||||
else if (data.past_key == nullptr && data.present_key != nullptr) {
|
||||
assert(data.past_value == nullptr);
|
||||
assert(data.present_value != nullptr);
|
||||
assert(data.query != nullptr);
|
||||
assert(data.key != nullptr);
|
||||
assert(data.value != nullptr);
|
||||
|
||||
// TODO: supporting packed qkv for self attention may benefit performance
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.query, q));
|
||||
|
||||
// TODO: supporting packed kv for cross attention may benefit performance
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.key, data.present_key));
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, false, data.value, data.present_value));
|
||||
}
|
||||
// self attention with past and present state
|
||||
else {
|
||||
assert(data.past_key != nullptr);
|
||||
assert(data.past_value != nullptr);
|
||||
assert(data.present_key != nullptr);
|
||||
assert(data.present_value != nullptr);
|
||||
assert(data.query != nullptr);
|
||||
assert(data.key != nullptr);
|
||||
assert(data.value != nullptr);
|
||||
// TODO: supporting packed qkv for self attention may benefit performance
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.query, q));
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.key, k));
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, false, data.value, v));
|
||||
}
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
|
||||
}
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
|
||||
// When past_key/past_value are inputted directly as key/value and there is no present_key/present_value
|
||||
else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
|
||||
data.past_key != nullptr &&
|
||||
data.past_value != nullptr &&
|
||||
parameters.pass_past_in_kv) {
|
||||
// Transpose past_key and past_value to use memory efficient attention
|
||||
|
||||
// past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
|
||||
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.past_key, data.temp_k_workspace));
|
||||
// past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
|
||||
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.past_value, data.temp_v_workspace));
|
||||
|
||||
// query => q, temp_k_workspace => k, temp_v_workspace => v
|
||||
LaunchAddBias(stream, max_threads_per_block,
|
||||
batch_size, sequence_length, kv_sequence_length,
|
||||
num_heads, qk_head_size, v_head_size,
|
||||
data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
|
||||
|
||||
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
|
||||
data.past_key = nullptr;
|
||||
data.past_value = nullptr;
|
||||
}
|
||||
// When there is no past_key/past_value and there is present_key/present_value
|
||||
// (e.g. get initial kv to use as past_kv in the next iteration)
|
||||
else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
|
||||
data.present_key != nullptr &&
|
||||
data.present_value != nullptr) {
|
||||
// Use memory efficient attention kernel
|
||||
LaunchAddBias(stream, max_threads_per_block,
|
||||
batch_size, sequence_length, kv_sequence_length,
|
||||
num_heads, qk_head_size, v_head_size,
|
||||
data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
|
||||
|
||||
// temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.temp_k_workspace, data.present_key));
|
||||
|
||||
// temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, false, data.temp_v_workspace, data.present_value));
|
||||
|
||||
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
// Use unfused kernel for Q, use unfused kernel for K and V if needed
|
||||
constexpr int format = 0;
|
||||
// Query (BxSxNxH) => Q (BxNxSxH)
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.query, data.bias, q,
|
||||
true, -1);
|
||||
|
||||
if (!parameters.pass_past_in_kv) {
|
||||
T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
|
||||
T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;
|
||||
|
||||
// Key (BxLxNxH) => K (BxNxLxH)
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, qk_head_size,
|
||||
data.key, data.bias + num_heads * qk_head_size, k_dest,
|
||||
true, -1);
|
||||
|
||||
// Value (BxLxNxH_v) => V (BxNxLxH_v)
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, v_head_size,
|
||||
data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
|
||||
true, -1);
|
||||
|
||||
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size);
|
||||
}
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For MultiHeadAttention without past state, with packed QKV inputs
|
||||
template <typename T>
|
||||
Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int qk_head_size = parameters.head_size;
|
||||
const int v_head_size = parameters.v_head_size;
|
||||
void* fused_runner = data.fused_runner;
|
||||
|
||||
T* qkv = data.workspace;
|
||||
|
||||
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
|
||||
|
||||
assert(data.bias == nullptr);
|
||||
assert(qk_head_size == v_head_size);
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);
|
||||
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
// unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
|
||||
constexpr int format = 4;
|
||||
T* qkv_add_bias = nullptr;
|
||||
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.query, data.bias, qkv,
|
||||
true, v_head_size, qkv_add_bias, 3);
|
||||
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
} else {
|
||||
if (!use_fused_kernel) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
|
||||
}
|
||||
|
||||
qkv_format = AttentionQkvFormat::QKV_BSN3H;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For MultiHeadAttention without past state, with packed KV inputs
|
||||
template <typename T>
|
||||
Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int kv_sequence_length = parameters.kv_sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int qk_head_size = parameters.head_size;
|
||||
const int v_head_size = parameters.v_head_size;
|
||||
|
||||
// TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
|
||||
// CheckInputs verified this constraint.
|
||||
assert(data.bias == nullptr);
|
||||
assert(qk_head_size == v_head_size);
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);
|
||||
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
// unpack kv to BSNH. Note that there is no bias so we need not output query to q.
|
||||
constexpr int format = 4;
|
||||
T* qkv_add_bias = nullptr;
|
||||
const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
|
||||
LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, qk_head_size,
|
||||
data.key, kv_bias, k,
|
||||
true, v_head_size, qkv_add_bias, 2);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
} else {
|
||||
if (data.fused_cross_attention_kernel == nullptr) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
|
||||
}
|
||||
|
||||
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// For MultiHeadAttention without past state, with Q, K and V inputs
|
||||
template <typename T>
|
||||
Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int kv_sequence_length = parameters.kv_sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int qk_head_size = parameters.head_size;
|
||||
const int v_head_size = parameters.v_head_size;
|
||||
void* fused_runner = data.fused_runner;
|
||||
|
||||
T* qkv = data.workspace;
|
||||
|
||||
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
|
||||
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
|
||||
|
||||
// gemm_buffer == nullptr and not packed
|
||||
assert(data.query != nullptr && data.key != nullptr && data.value != nullptr);
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size);
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 1
|
||||
if (data.bias != nullptr) {
|
||||
DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
|
||||
DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
|
||||
num_heads, sequence_length, kv_sequence_length);
|
||||
}
|
||||
|
||||
if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
|
||||
DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
|
||||
}
|
||||
|
||||
if (data.fused_cross_attention_kernel != nullptr) {
|
||||
assert(qk_head_size == v_head_size);
|
||||
|
||||
// For fused cross attention, besides adding bias, K and V needed to be packed:
|
||||
// K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH
|
||||
LaunchAddBiasTransposeTrt(
|
||||
stream, max_threads_per_block,
|
||||
batch_size, sequence_length,
|
||||
num_heads, qk_head_size,
|
||||
data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
|
||||
|
||||
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
|
||||
}
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
|
||||
else if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
LaunchAddBias(stream, max_threads_per_block,
|
||||
batch_size, sequence_length, kv_sequence_length,
|
||||
num_heads, qk_head_size, v_head_size,
|
||||
data.bias, data.query, data.key, data.value, q, k, v);
|
||||
|
||||
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
}
|
||||
#endif
|
||||
else if (use_fused_kernel) {
|
||||
assert(qk_head_size == v_head_size);
|
||||
|
||||
// Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
|
||||
LaunchAddBiasTransposeTrt(
|
||||
stream, max_threads_per_block,
|
||||
batch_size, sequence_length,
|
||||
num_heads, qk_head_size,
|
||||
data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
|
||||
DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
|
||||
|
||||
qkv_format = AttentionQkvFormat::QKV_BSN3H;
|
||||
} else { // unfused kernel
|
||||
ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
|
||||
|
||||
// Query (BxSxNxH) => Q (BxNxSxH)
|
||||
constexpr int format = 0;
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.query, data.bias, q,
|
||||
true, -1);
|
||||
|
||||
// Key (BxLxNxH) => K (BxNxLxH)
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, qk_head_size,
|
||||
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
|
||||
true, -1);
|
||||
|
||||
// Value (BxLxNxH_v) => V (BxNxLxH_v)
|
||||
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, v_head_size,
|
||||
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
|
||||
true, -1);
|
||||
|
||||
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Dispatches Q/K/V preparation to the helper that matches the inputs present in `data`:
//   - gemm_buffer set                  -> Attention operator
//   - past_key or present_key set      -> MultiHeadAttention with past/present state
//   - key == nullptr                   -> MultiHeadAttention, no past, packed QKV
//   - value == nullptr                 -> MultiHeadAttention, no past, packed KV
//   - otherwise                        -> MultiHeadAttention, no past, separate Q/K/V inputs
// The chosen helper reports the resulting layout through `qkv_format`.
// Returns the helper's error status if it failed, otherwise the result of cudaGetLastError().
template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
                  AttentionData<T>& data,
                  cudaStream_t stream,
                  int max_threads_per_block,
                  T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
  const bool is_attention_op = (data.gemm_buffer != nullptr);
  const bool has_past_or_present = (data.past_key != nullptr || data.present_key != nullptr);

  Status prepare_status;
  if (is_attention_op) {
    // Attention operator
    prepare_status = PrepareQkv_Attention<T>(parameters, data, stream, max_threads_per_block, qkv_format);
  } else if (has_past_or_present) {
    // mha operator with past/present state
    prepare_status = PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block,
                                             q, k, v, qkv_format);
  } else if (data.key == nullptr) {
    // multihead attention operator, no past, packed qkv
    prepare_status = PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block,
                                              q, k, v, qkv_format);
  } else if (data.value == nullptr) {
    // multihead attention operator, no past, packed kv
    prepare_status = PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block,
                                             q, k, v, qkv_format);
  } else {
    // multihead attention operator, no past, separated Q/K/V inputs
    prepare_status = PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block,
                                              q, k, v, qkv_format);
  }
  ORT_RETURN_IF_ERROR(prepare_status);

  // Surface any launch error left behind by the kernels the helper enqueued.
  CUDA_RETURN_IF_ERROR(cudaGetLastError());
  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
|
|
@ -755,92 +311,22 @@ Status QkvToContext(
|
|||
|
||||
const int batches = batch_size * num_heads;
|
||||
|
||||
T* qkv = nullptr;
|
||||
T* q = nullptr;
|
||||
T* k = nullptr;
|
||||
T* v = nullptr;
|
||||
T* scratch1 = data.workspace;
|
||||
if (data.has_qkv_workspace) {
|
||||
const int size_per_batch_q = sequence_length * qk_head_size;
|
||||
const int size_per_batch_k = kv_sequence_length * qk_head_size;
|
||||
const int size_per_batch_v = kv_sequence_length * v_head_size;
|
||||
const size_t elements_q = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_q);
|
||||
const size_t elements_k = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_k);
|
||||
const size_t elements_v = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_v);
|
||||
qkv = data.workspace;
|
||||
q = qkv;
|
||||
k = q + elements_q;
|
||||
v = k + elements_k;
|
||||
scratch1 = v + elements_v;
|
||||
}
|
||||
|
||||
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
|
||||
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
|
||||
|
||||
AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
|
||||
QkvData<T> qkv;
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block, qkv));
|
||||
T* scratch1 = data.has_qkv_workspace ? qkv.after_v : data.workspace;
|
||||
|
||||
int present_size_per_batch_k = 0;
|
||||
int present_size_per_batch_v = 0;
|
||||
if (!past_present_share_buffer) {
|
||||
// Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
|
||||
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
|
||||
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
|
||||
// When there is past state, the head size for Q/K/V shall be same: H == H_v.
|
||||
present_size_per_batch_k = total_sequence_length * qk_head_size;
|
||||
present_size_per_batch_v = total_sequence_length * v_head_size;
|
||||
ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size,
|
||||
sequence_length, total_sequence_length, parameters.pass_past_in_kv,
|
||||
stream, max_threads_per_block, data, qkv));
|
||||
|
||||
if (nullptr != data.present) {
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatPastToPresent(
|
||||
stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, data.past, k, data.present));
|
||||
|
||||
// Update pointers to present_k and present_v.
|
||||
k = data.present;
|
||||
v = data.present + batches * present_size_per_batch_k;
|
||||
}
|
||||
|
||||
if (nullptr != data.past_key || nullptr != data.present_key) {
|
||||
if (nullptr != data.past_key && nullptr == data.present_key) {
|
||||
k = const_cast<T*>(data.past_key);
|
||||
v = const_cast<T*>(data.past_value);
|
||||
} else if (nullptr == data.past_key && nullptr != data.present_key) {
|
||||
if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
|
||||
k = data.present_key;
|
||||
v = data.present_value;
|
||||
} else {
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
k = data.temp_k_workspace;
|
||||
v = data.temp_v_workspace;
|
||||
}
|
||||
} else if (parameters.pass_past_in_kv) {
|
||||
// past_key and past_value are used directly as key and value in attention computations
|
||||
k = const_cast<T*>(data.past_key);
|
||||
v = const_cast<T*>(data.past_value);
|
||||
|
||||
// This path has a memory copy from past_key and past_value to present_key and present_value
|
||||
// Avoid this path since the memory copy is unnecessary because past_key == present_key and
|
||||
// past_value == present_value
|
||||
int64_t k_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * qk_head_size;
|
||||
int64_t v_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * v_head_size;
|
||||
cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
|
||||
cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
|
||||
batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, 1, data.past_key, k, data.present_key));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
|
||||
batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, 1, data.past_value, v, data.present_value));
|
||||
// Update pointers to present_k and present_v.
|
||||
k = data.present_key;
|
||||
v = data.present_value;
|
||||
}
|
||||
}
|
||||
} else { // past_present_share_buffer
|
||||
assert(qk_head_size == v_head_size);
|
||||
assert(data.fused_cross_attention_kernel == nullptr);
|
||||
|
|
@ -870,15 +356,15 @@ Status QkvToContext(
|
|||
|
||||
present_size_per_batch_k = parameters.max_sequence_length * qk_head_size;
|
||||
present_size_per_batch_v = present_size_per_batch_k;
|
||||
k = data.present;
|
||||
v = data.present + batches * present_size_per_batch_k;
|
||||
qkv.k = data.present;
|
||||
qkv.v = data.present + batches * present_size_per_batch_k;
|
||||
}
|
||||
|
||||
// Q, K and V are ready now
|
||||
DUMP_TENSOR_INIT();
|
||||
|
||||
if (data.fused_cross_attention_kernel != nullptr) {
|
||||
assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
|
||||
assert(qkv.format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
|
||||
|
||||
// We only enable fused cross attention when there is no key padding mask.
|
||||
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
|
||||
|
|
@ -902,8 +388,8 @@ Status QkvToContext(
|
|||
reinterpret_cast<FusedMultiHeadCrossAttentionKernel const*>(data.fused_cross_attention_kernel);
|
||||
|
||||
// When there is no bias, we can directly use q and packed kv from inputs.
|
||||
void const* query = q;
|
||||
void const* packed_kv = k;
|
||||
void const* query = qkv.q;
|
||||
void const* packed_kv = qkv.k;
|
||||
if (data.value == nullptr && data.bias == nullptr) {
|
||||
query = data.query;
|
||||
packed_kv = data.key;
|
||||
|
|
@ -951,10 +437,10 @@ Status QkvToContext(
|
|||
fused_fp16_runner->setup(S, B);
|
||||
|
||||
if (use_fused_kernel) {
|
||||
assert(qkv_format == AttentionQkvFormat::QKV_BSN3H);
|
||||
assert(qkv.format == AttentionQkvFormat::QKV_BSN3H);
|
||||
|
||||
// When there is no bias, we can directly use packed qkv from inputs.
|
||||
void const* packed_qkv = qkv;
|
||||
void const* packed_qkv = qkv.q;
|
||||
if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) {
|
||||
packed_qkv = data.query;
|
||||
}
|
||||
|
|
@ -962,7 +448,7 @@ Status QkvToContext(
|
|||
fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
|
||||
DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
|
||||
} else {
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
|
||||
fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
|
||||
DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size);
|
||||
}
|
||||
|
|
@ -975,22 +461,22 @@ Status QkvToContext(
|
|||
|
||||
#if USE_FLASH_ATTENTION
|
||||
if (data.use_flash_attention) {
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
assert(nullptr == data.mask_index);
|
||||
assert(nullptr == data.relative_position_bias);
|
||||
assert(parameters.head_size == parameters.v_head_size);
|
||||
|
||||
void* query = reinterpret_cast<void*>(q);
|
||||
void* key = reinterpret_cast<void*>(k);
|
||||
void* value = reinterpret_cast<void*>(v);
|
||||
void* query = reinterpret_cast<void*>(qkv.q);
|
||||
void* key = reinterpret_cast<void*>(qkv.k);
|
||||
void* value = reinterpret_cast<void*>(qkv.v);
|
||||
// For packed KV, we can use query input directly.
|
||||
if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) {
|
||||
query = reinterpret_cast<void*>(const_cast<T*>(data.query));
|
||||
}
|
||||
|
||||
DUMP_TENSOR_D("q(BSNH)", reinterpret_cast<const T*>(query), batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
|
||||
|
||||
constexpr bool is_causal = false;
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
|
||||
|
|
@ -1008,11 +494,11 @@ Status QkvToContext(
|
|||
if (data.use_memory_efficient_attention) {
|
||||
// We only enable fused cross attention when there is no key padding mask.
|
||||
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
|
||||
const void* query = q;
|
||||
const void* key = k;
|
||||
const void* value = v;
|
||||
const void* query = qkv.q;
|
||||
const void* key = qkv.k;
|
||||
const void* value = qkv.v;
|
||||
// For packed KV, we can use query input directly.
|
||||
if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) {
|
||||
assert(data.bias == nullptr);
|
||||
|
|
@ -1020,8 +506,8 @@ Status QkvToContext(
|
|||
}
|
||||
|
||||
DUMP_TENSOR_D("q(BSNH)", reinterpret_cast<const T*>(query), batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
|
||||
|
||||
MemoryEfficientAttentionParams p;
|
||||
p.sm = device_prop.major * 10 + device_prop.minor;
|
||||
|
|
@ -1061,7 +547,7 @@ Status QkvToContext(
|
|||
#endif
|
||||
|
||||
// The following are unfused attention.
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
const int* mask_index = data.mask_index;
|
||||
gsl::span<const int64_t>& mask_index_dims = data.mask_index_dims;
|
||||
|
||||
|
|
@ -1082,12 +568,12 @@ Status QkvToContext(
|
|||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
total_sequence_length, sequence_length, qk_head_size,
|
||||
&alpha, k, qk_head_size, present_size_per_batch_k,
|
||||
q, qk_head_size, sequence_length * qk_head_size,
|
||||
&alpha, qkv.k, qk_head_size, present_size_per_batch_k,
|
||||
qkv.q, qk_head_size, sequence_length * qk_head_size,
|
||||
&zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
|
||||
|
||||
DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size);
|
||||
DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length);
|
||||
DUMP_TENSOR_D("Q", qkv.q, batch_size, num_heads, sequence_length, qk_head_size);
|
||||
DUMP_TENSOR_D("K", qkv.k, batch_size, num_heads, qk_head_size, sequence_length);
|
||||
DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length);
|
||||
|
||||
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
|
||||
|
|
@ -1126,14 +612,14 @@ Status QkvToContext(
|
|||
}
|
||||
|
||||
DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
|
||||
DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size);
|
||||
DUMP_TENSOR_D("V", qkv.v, batch_size, num_heads, sequence_length, v_head_size);
|
||||
|
||||
// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
|
||||
T* temp_output = qkv;
|
||||
T* temp_output = qkv.q;
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
v_head_size, sequence_length, total_sequence_length,
|
||||
&one, v, v_head_size, present_size_per_batch_v,
|
||||
&one, qkv.v, v_head_size, present_size_per_batch_v,
|
||||
scratch2, total_sequence_length, sequence_length * total_sequence_length,
|
||||
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cublas_v2.h>
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "core/common/gsl.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -49,39 +50,56 @@ size_t GetAttentionWorkspaceSize(
|
|||
|
||||
template <typename T>
|
||||
struct AttentionData {
|
||||
T* gemm_buffer;
|
||||
const T* bias;
|
||||
T* gemm_buffer = nullptr;
|
||||
const T* bias = nullptr;
|
||||
|
||||
const T* query;
|
||||
const T* key;
|
||||
const T* value;
|
||||
const int* mask_index;
|
||||
const T* query = nullptr;
|
||||
const T* key = nullptr;
|
||||
const T* value = nullptr;
|
||||
const int* mask_index = nullptr;
|
||||
gsl::span<const int64_t> mask_index_dims;
|
||||
const T* past;
|
||||
const T* past_key;
|
||||
const T* past_value;
|
||||
const T* relative_position_bias;
|
||||
const T* past = nullptr;
|
||||
const T* past_key = nullptr;
|
||||
const T* past_value = nullptr;
|
||||
const T* relative_position_bias = nullptr;
|
||||
|
||||
bool has_qkv_workspace;
|
||||
T* workspace;
|
||||
T* temp_k_workspace;
|
||||
T* temp_v_workspace;
|
||||
bool has_qkv_workspace = false;
|
||||
T* workspace = nullptr;
|
||||
T* temp_k_workspace = nullptr;
|
||||
T* temp_v_workspace = nullptr;
|
||||
|
||||
T* output;
|
||||
T* present;
|
||||
T* present_key;
|
||||
T* present_value;
|
||||
T* output = nullptr;
|
||||
T* present = nullptr;
|
||||
T* present_key = nullptr;
|
||||
T* present_value = nullptr;
|
||||
|
||||
void* fused_runner;
|
||||
const void* fused_cross_attention_kernel;
|
||||
void* fused_runner = nullptr;
|
||||
const void* fused_cross_attention_kernel = nullptr;
|
||||
|
||||
bool use_flash_attention;
|
||||
bool use_memory_efficient_attention;
|
||||
bool use_flash_attention = false;
|
||||
bool use_memory_efficient_attention = false;
|
||||
|
||||
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache;
|
||||
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache;
|
||||
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr;
|
||||
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr;
|
||||
};
|
||||
|
||||
// Intermediate data pointers available after PrepareQKV, plus the layout they are stored in.
// All pointers default to nullptr and the format defaults to Q_K_V_BSNH.
template <typename T>
struct QkvData {
  T* q{nullptr};
  T* k{nullptr};
  T* v{nullptr};
  T* after_v{nullptr};  // pointer right after v
  AttentionQkvFormat format{AttentionQkvFormat::Q_K_V_BSNH};
};
|
||||
|
||||
template <typename T>
|
||||
Status PrepareQkv(contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
QkvData<T>& qkv_data);
|
||||
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
|
|
@ -161,27 +179,13 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
|
|||
const half* tensor_add,
|
||||
half* tensor_out);
|
||||
|
||||
Status LaunchConcatPastToPresent(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const float* past,
|
||||
const float* k_v,
|
||||
float* present);
|
||||
|
||||
Status LaunchConcatPastToPresent(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int max_threads_per_block,
|
||||
const half* past,
|
||||
const half* k_v,
|
||||
half* present);
|
||||
template <typename T>
|
||||
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
|
||||
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
AttentionData<T>& data,
|
||||
QkvData<T>& qkv);
|
||||
|
||||
template <typename T>
|
||||
Status LaunchStridedCopy(cudaStream_t stream,
|
||||
|
|
|
|||
492
onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
Normal file
492
onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
Normal file
|
|
@ -0,0 +1,492 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/add_bias_transpose.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// Prepares Q, K and V for the Attention operator from data.gemm_buffer, writing into data.workspace.
//
// Without bias, gemm_buffer is only transposed (LaunchTransQkv) and the result format is Q_K_V_BNSH.
// With bias, LaunchAddBiasTranspose adds the bias and transposes into the layout required by the
// kernel that will run next; the chosen layout is reported through `qkv_format`.
//
// Returns the error from LaunchTransQkv if it fails, otherwise Status::OK().
template <typename T>
Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
                            AttentionData<T>& data,
                            cudaStream_t stream,
                            int max_threads_per_block,
                            AttentionQkvFormat& qkv_format) {
  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int num_heads = parameters.num_heads;
  const int qk_head_size = parameters.head_size;
  const int v_head_size = parameters.v_head_size;
  const bool past_present_share_buffer = parameters.past_present_share_buffer;

  void* fused_runner = data.fused_runner;
  const bool has_fused_runner = (nullptr != fused_runner);
  const bool use_fused_kernel = has_fused_runner && !parameters.is_unidirectional;
  const bool use_fused_causal = has_fused_runner && parameters.is_unidirectional;
  const bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention;

  T* qkv = data.workspace;

  if (data.bias == nullptr) {
    assert(nullptr == fused_runner);
    // For quantized attention, bias has been added so only need transpose here.
    // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
    assert(qk_head_size == v_head_size);
    const int matrix_to_trans = past_present_share_buffer ? 1 : 3;
    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
                                       max_threads_per_block, false, data.gemm_buffer, qkv, 3));
    qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
    return Status::OK();
  }

  // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
  // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
  // For unfused kernel, transpose to 3xBxNxSxH (format 1)
  // For fused causal kernel, use format 1 since we need have K and V to update present state,
  // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
  int format = 1;
  qkv_format = use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH;
  if (use_fused_kernel) {
    format = 2;
    qkv_format = AttentionQkvFormat::QKV_BSN3H;
  } else if (use_flash_or_efficient_attention) {
    format = 3;
    qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
  }

  // For fused causal, we will update gemm_buffer with bias directly.
  T* qkv_add_bias = nullptr;
  if (use_fused_causal) {
    qkv_add_bias = data.gemm_buffer;
  }

  // NOTE(review): `format` is the integer transpose format (1, 2 or 3) but it is compared against an
  // AttentionQkvFormat enumerator. This matches "format 1" only if Q_K_V_BNSH has the value 1 --
  // confirm against the enum definition.
  int matrix_to_transpose = 3;
  if (format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) {
    matrix_to_transpose = 1;
  }

  // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
  // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
  LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
                         batch_size, sequence_length, num_heads, qk_head_size,
                         data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
                         3, parameters.do_rotary, parameters.past_sequence_length);
  return Status::OK();
}
|
||||
|
||||
// For MultiHeadAttention with past state.
//
// Prepares Q, K and V when at least one of past_key / present_key is provided. Depending on which
// of bias, past_* and present_* are set, and on whether flash / memory efficient attention is
// enabled, the inputs are transposed and/or biased into q, k, v, the present_* buffers or the
// temp_*_workspace buffers. The resulting layout is reported through `qkv_format`
// (Q_K_V_BNSH for the unfused paths, Q_K_V_BSNH for flash / memory efficient attention).
//
// Side effect: on the flash / memory efficient path with pass_past_in_kv, data.past_key and
// data.past_value are reset to nullptr after they have been consumed.
//
// Returns the first error from a transpose launch helper, otherwise Status::OK().
template <typename T>
Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters,
                               AttentionData<T>& data,
                               cudaStream_t stream,
                               int max_threads_per_block,
                               T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int kv_sequence_length = parameters.kv_sequence_length;
  const int num_heads = parameters.num_heads;
  const int qk_head_size = parameters.head_size;
  const int v_head_size = parameters.v_head_size;

  DUMP_TENSOR_INIT();

  if (data.bias == nullptr) {
    // Below logic does not support fused attention with past without bias
    // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past.

    // cross attention with past state
    if (data.past_key != nullptr && data.present_key == nullptr) {
      assert(data.past_value != nullptr);
      assert(data.query != nullptr);
      assert(data.key == nullptr);
      assert(data.value == nullptr);
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
                                         max_threads_per_block, false, data.query, q));
    }
    // cross attention with present state or self attention with present state
    else if (data.past_key == nullptr && data.present_key != nullptr) {
      assert(data.past_value == nullptr);
      assert(data.present_value != nullptr);
      assert(data.query != nullptr);
      assert(data.key != nullptr);
      assert(data.value != nullptr);

      // TODO: supporting packed qkv for self attention may benefit performance
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
                                         max_threads_per_block, false, data.query, q));

      // TODO: supporting packed kv for cross attention may benefit performance
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
                                         max_threads_per_block, false, data.key, data.present_key));
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
                                         max_threads_per_block, false, data.value, data.present_value));
    }
    // self attention with past and present state
    else {
      assert(data.past_key != nullptr);
      assert(data.past_value != nullptr);
      assert(data.present_key != nullptr);
      assert(data.present_value != nullptr);
      assert(data.query != nullptr);
      assert(data.key != nullptr);
      assert(data.value != nullptr);
      // TODO: supporting packed qkv for self attention may benefit performance
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
                                         max_threads_per_block, false, data.query, q));
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
                                         max_threads_per_block, false, data.key, k));
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
                                         max_threads_per_block, false, data.value, v));
    }
    qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
  }
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
  // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value
  else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
           data.past_key != nullptr &&
           data.past_value != nullptr &&
           parameters.pass_past_in_kv) {
    // Transpose past_key and past_value to use memory efficient attention

    // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
    ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
                                       max_threads_per_block, false, data.past_key, data.temp_k_workspace));
    // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
    // The value tensor uses the V head size, which may differ from the Q/K head size.
    ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, v_head_size, num_heads,
                                       max_threads_per_block, false, data.past_value, data.temp_v_workspace));

    // query => q, temp_k_workspace => k, temp_v_workspace => v
    LaunchAddBias(stream, max_threads_per_block,
                  batch_size, sequence_length, kv_sequence_length,
                  num_heads, qk_head_size, v_head_size,
                  data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);

    DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
    qkv_format = AttentionQkvFormat::Q_K_V_BSNH;

    // The past state has been consumed into k and v; clear the pointers on `data`.
    data.past_key = nullptr;
    data.past_value = nullptr;
  }
  // When there is no past_key/past_value and there is present_key/present_value
  // (e.g. get initial kv to use as past_kv in the next iteration)
  else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
           data.present_key != nullptr &&
           data.present_value != nullptr) {
    // Use memory efficient attention kernel
    LaunchAddBias(stream, max_threads_per_block,
                  batch_size, sequence_length, kv_sequence_length,
                  num_heads, qk_head_size, v_head_size,
                  data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);

    // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
                                       max_threads_per_block, false, data.temp_k_workspace, data.present_key));

    // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
                                       max_threads_per_block, false, data.temp_v_workspace, data.present_value));

    DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
    qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
  }
#endif
  else {
    // Use unfused kernel for Q, use unfused kernel for K and V if needed
    constexpr int format = 0;
    // Query (BxSxNxH) => Q (BxNxSxH)
    LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
                              batch_size, sequence_length, num_heads, qk_head_size,
                              data.query, data.bias, q,
                              true, -1);

    if (!parameters.pass_past_in_kv) {
      // Without past state, write K and V straight into the present buffers; otherwise use k / v.
      T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
      T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;

      // Key (BxLxNxH) => K (BxNxLxH)
      LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
                                batch_size, kv_sequence_length, num_heads, qk_head_size,
                                data.key, data.bias + num_heads * qk_head_size, k_dest,
                                true, -1);

      // Value (BxLxNxH_v) => V (BxNxLxH_v)
      LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
                                batch_size, kv_sequence_length, num_heads, v_head_size,
                                data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
                                true, -1);

      DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
      DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size);
      DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size);
    }
    qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
  }
  return Status::OK();
}
|
||||
|
||||
// For MultiHeadAttention without past state, with packed QKV inputs.
//
// The packed tensor is read from data.query (dumped as (B*S) x N x 3 x H). Two outcomes are possible:
//   * flash or memory efficient attention is selected: the packed tensor is unpacked into data.workspace
//     and qkv_format is set to Q_K_V_BSNH;
//   * otherwise a fused runner with a non-unidirectional configuration is required. Nothing is launched
//     here; qkv_format is set to QKV_BSN3H.
// Returns a NOT_IMPLEMENTED status when neither path is available, Status::OK() otherwise.
//
// Preconditions (checked by assert only): data.bias is null and qk_head_size == v_head_size.
// q, k and v are referenced only by the tensor dumps in this function.
// NOTE(review): presumably q aliases the start of data.workspace when the caller set up a QKV workspace,
// which is why the unpack writes to data.workspace rather than to q - confirm against the caller.
template <typename T>
Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
                                AttentionData<T>& data,
                                cudaStream_t stream,
                                int max_threads_per_block,
                                T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int num_heads = parameters.num_heads;
  const int qk_head_size = parameters.head_size;
  const int v_head_size = parameters.v_head_size;

  assert(data.bias == nullptr);
  assert(qk_head_size == v_head_size);

  DUMP_TENSOR_INIT();
  DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);

  const bool use_flash_or_efficient_attention = data.use_memory_efficient_attention || data.use_flash_attention;
  if (use_flash_or_efficient_attention) {
    // Unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
    constexpr int format = 4;
    T* unpacked_qkv = data.workspace;
    // Typed null pointer (rather than a bare nullptr) so that template argument deduction still works.
    T* qkv_add_bias = nullptr;
    LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
                           batch_size, sequence_length, num_heads, qk_head_size,
                           data.query, data.bias, unpacked_qkv,
                           true, v_head_size, qkv_add_bias, 3);
    DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
    qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
    return Status::OK();
  }

  // The only remaining option consumes the packed input as is, via the fused (non-causal) runner.
  const bool fused_kernel_available = (data.fused_runner != nullptr) && !parameters.is_unidirectional;
  if (!fused_kernel_available) {
    return ORT_MAKE_STATUS(
        ONNXRUNTIME, NOT_IMPLEMENTED,
        "packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
  }

  qkv_format = AttentionQkvFormat::QKV_BSN3H;
  return Status::OK();
}
|
||||
|
||||
// For MultiHeadAttention without past state, with packed KV inputs.
//
// The packed tensor is read from data.key (dumped as (B*L) x N x 2 x H). Two outcomes are possible:
//   * flash or memory efficient attention is selected: the packed tensor is unpacked starting at k and
//     qkv_format is set to Q_K_V_BSNH;
//   * otherwise a fused cross attention kernel is required. Nothing is launched here; qkv_format is set
//     to Q_KV_BSNH_BSN2H.
// Returns a NOT_IMPLEMENTED status when neither path is available, Status::OK() otherwise.
// q is not referenced by this function.
template <typename T>
Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
                               AttentionData<T>& data,
                               cudaStream_t stream,
                               int max_threads_per_block,
                               T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
  const int batch_size = parameters.batch_size;
  const int kv_sequence_length = parameters.kv_sequence_length;
  const int num_heads = parameters.num_heads;
  const int qk_head_size = parameters.head_size;
  const int v_head_size = parameters.v_head_size;

  // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
  // CheckInputs verified this constraint.
  assert(data.bias == nullptr);
  assert(qk_head_size == v_head_size);

  DUMP_TENSOR_INIT();
  DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);

  const bool use_flash_or_efficient_attention = data.use_memory_efficient_attention || data.use_flash_attention;
  if (!use_flash_or_efficient_attention) {
    // Packed KV is passed through untouched; only the fused cross attention kernel is accepted here.
    if (data.fused_cross_attention_kernel == nullptr) {
      return ORT_MAKE_STATUS(
          ONNXRUNTIME, NOT_IMPLEMENTED,
          "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
    }

    qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
    return Status::OK();
  }

  // Unpack kv to BSNH. Note that there is no bias so we need not output query to q.
  constexpr int format = 4;
  // Typed null pointer (rather than a bare nullptr) so that template argument deduction still works.
  T* qkv_add_bias = nullptr;

  // Skip the first hidden_size bias elements when a bias is present.
  const T* kv_bias = nullptr;
  if (data.bias != nullptr) {
    kv_bias = data.bias + parameters.hidden_size;
  }

  LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
                         batch_size, kv_sequence_length, num_heads, qk_head_size,
                         data.key, kv_bias, k,
                         true, v_head_size, qkv_add_bias, 2);
  DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
  DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
  qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
  return Status::OK();
}
|
||||
|
||||
// For MultiHeadAttention without past state, with Q, K and V inputs.
//
// Adds the (optional) bias to the separate query/key/value inputs and lays the result out for the
// kernel that will consume it. The first matching case wins:
//   1. fused cross attention kernel  -> K and V packed into data.workspace, format Q_KV_BSNH_BSN2H
//   2. flash / memory efficient      -> bias added into q, k, v,            format Q_K_V_BSNH
//      (only compiled when USE_MEMORY_EFFICIENT_ATTENTION or USE_FLASH_ATTENTION is set)
//   3. fused runner, not causal      -> Q, K, V packed into data.workspace, format QKV_BSN3H
//   4. unfused kernel                -> Q, K, V transposed into q, k, v,    format Q_K_V_BNSH
// The unfused case enforces (ORT_ENFORCE) that no fused runner was supplied together with
// is_unidirectional.
// Always returns Status::OK(); the selected layout is reported through qkv_format.
template <typename T>
Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
                                AttentionData<T>& data,
                                cudaStream_t stream,
                                int max_threads_per_block,
                                T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int kv_sequence_length = parameters.kv_sequence_length;
  const int num_heads = parameters.num_heads;
  const int qk_head_size = parameters.head_size;
  const int v_head_size = parameters.v_head_size;

  const bool has_fused_runner = (data.fused_runner != nullptr);
  const bool use_fused_kernel = has_fused_runner && !parameters.is_unidirectional;
  const bool use_fused_causal = has_fused_runner && parameters.is_unidirectional;

  // Destination of the packed layouts (cases 1 and 3).
  T* workspace = data.workspace;

  // gemm_buffer == nullptr and not packed
  assert(data.query != nullptr && data.key != nullptr && data.value != nullptr);

  DUMP_TENSOR_INIT();
  DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size);
  DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size);
  DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size);

#if DUMP_TENSOR_LEVEL > 1
  if (data.bias != nullptr) {
    DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
    DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
    DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
  }
#endif

  if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
    DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
                  num_heads, sequence_length, kv_sequence_length);
  }

  if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
    DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
  }

  // Case 1: fused cross attention.
  if (data.fused_cross_attention_kernel != nullptr) {
    assert(qk_head_size == v_head_size);

    // For fused cross attention, besides adding bias, K and V needed to be packed:
    //   K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH
    LaunchAddBiasTransposeTrt(
        stream, max_threads_per_block,
        batch_size, sequence_length,
        num_heads, qk_head_size,
        data.bias, data.query, data.key, data.value, workspace, true, kv_sequence_length);

    qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
    return Status::OK();
  }

#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
  // Case 2: flash / memory efficient attention keeps the BSNH layout and only needs the bias added.
  if (data.use_memory_efficient_attention || data.use_flash_attention) {
    LaunchAddBias(stream, max_threads_per_block,
                  batch_size, sequence_length, kv_sequence_length,
                  num_heads, qk_head_size, v_head_size,
                  data.bias, data.query, data.key, data.value, q, k, v);

    DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
    DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
    qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
    return Status::OK();
  }
#endif

  // Case 3: fused (non-causal) runner.
  if (use_fused_kernel) {
    assert(qk_head_size == v_head_size);

    // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
    LaunchAddBiasTransposeTrt(
        stream, max_threads_per_block,
        batch_size, sequence_length,
        num_heads, qk_head_size,
        data.bias, data.query, data.key, data.value, workspace, false, kv_sequence_length);
    DUMP_TENSOR_D("qkv(BSN3H)", workspace, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);

    qkv_format = AttentionQkvFormat::QKV_BSN3H;
    return Status::OK();
  }

  // Case 4: unfused kernel.
  ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");

  constexpr int format = 0;

  // The bias, when present, holds the query part first, then the key part, then the value part.
  const T* key_bias = (nullptr == data.bias) ? nullptr : data.bias + num_heads * qk_head_size;
  const T* value_bias = (nullptr == data.bias) ? nullptr : data.bias + 2 * num_heads * qk_head_size;

  // Query (BxSxNxH) => Q (BxNxSxH)
  LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
                            batch_size, sequence_length, num_heads, qk_head_size,
                            data.query, data.bias, q,
                            true, -1);

  // Key (BxLxNxH) => K (BxNxLxH)
  LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
                            batch_size, kv_sequence_length, num_heads, qk_head_size,
                            data.key, key_bias, k,
                            true, -1);

  // Value (BxLxNxH_v) => V (BxNxLxH_v)
  LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
                            batch_size, kv_sequence_length, num_heads, v_head_size,
                            data.value, value_bias, v,
                            true, -1);

  DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
  DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
  DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size);
  qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
  return Status::OK();
}
|
||||
|
||||
// Prepares Q, K and V for the attention kernels and records where they live in `qkv`.
//
// When data.has_qkv_workspace is set, data.workspace is partitioned into three consecutive regions
// (element counts, not bytes):
//   q:       batch_size * num_heads * sequence_length    * head_size
//   k:       batch_size * num_heads * kv_sequence_length * head_size
//   v:       batch_size * num_heads * kv_sequence_length * v_head_size
//   after_v: first element past the v region
// When it is not set, the pointers in `qkv` are left as the caller provided them.
//
// The input layout is then dispatched to the matching helper, which fills qkv.format:
//   - data.gemm_buffer != nullptr                  : Attention operator
//   - past_key or present_key is set               : MultiHeadAttention with past/present state
//   - data.key == nullptr                          : MultiHeadAttention, no past, packed QKV
//   - data.value == nullptr                        : MultiHeadAttention, no past, packed KV
//   - otherwise                                    : MultiHeadAttention, no past, separate Q/K/V
//
// Returns the first failing status from the helper, or the status derived from cudaGetLastError()
// after the helper has returned.
template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
                  AttentionData<T>& data,
                  cudaStream_t stream,
                  int max_threads_per_block,
                  QkvData<T>& qkv) {
  if (data.has_qkv_workspace) {
    // Widen every factor to size_t *before* multiplying: the previous int * int products
    // (e.g. sequence_length * head_size) could overflow for large shapes, which is undefined behavior.
    const size_t batches = static_cast<size_t>(parameters.batch_size) *
                           static_cast<size_t>(parameters.num_heads);
    const size_t elements_q = batches *
                              static_cast<size_t>(parameters.sequence_length) *
                              static_cast<size_t>(parameters.head_size);
    const size_t elements_k = batches *
                              static_cast<size_t>(parameters.kv_sequence_length) *
                              static_cast<size_t>(parameters.head_size);
    const size_t elements_v = batches *
                              static_cast<size_t>(parameters.kv_sequence_length) *
                              static_cast<size_t>(parameters.v_head_size);
    qkv.q = data.workspace;
    qkv.k = data.workspace + elements_q;
    qkv.v = qkv.k + elements_k;
    qkv.after_v = qkv.v + elements_v;
  }

  if (nullptr != data.gemm_buffer) {  // Attention operator
    ORT_RETURN_IF_ERROR(PrepareQkv_Attention<T>(parameters, data, stream, max_threads_per_block,
                                                qkv.format));
  } else if (data.past_key != nullptr || data.present_key != nullptr) {  // mha operator with past/present state
    ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block,
                                                qkv.q, qkv.k, qkv.v, qkv.format));
  } else if (data.key == nullptr) {  // multihead attention operator, no past, packed qkv
    ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block,
                                                 qkv.q, qkv.k, qkv.v, qkv.format));
  } else if (data.value == nullptr) {  // multihead attention operator, no past, packed kv
    ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block,
                                                qkv.q, qkv.k, qkv.v, qkv.format));
  } else {  // multihead attention operator, no past, separated Q/K/V inputs
    ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block,
                                                 qkv.q, qkv.k, qkv.v, qkv.format));
  }

  // Kernel launches do not report errors directly; pick up any launch failure from the helpers here.
  CUDA_RETURN_IF_ERROR(cudaGetLastError());
  return Status::OK();
}
|
||||
|
||||
// Explicit instantiations of PrepareQkv for float and half.
template Status PrepareQkv<float>(contrib::AttentionParameters& parameters,
                                  AttentionData<float>& data,
                                  cudaStream_t stream,
                                  int max_threads_per_block,
                                  QkvData<float>& qkv);

template Status PrepareQkv<half>(contrib::AttentionParameters& parameters,
                                 AttentionData<half>& data,
                                 cudaStream_t stream,
                                 int max_threads_per_block,
                                 QkvData<half>& qkv);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -263,14 +263,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
AttentionData<CudaT> data;
|
||||
data.gemm_buffer = nullptr;
|
||||
data.bias = (nullptr == bias) ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
|
||||
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
|
||||
data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
|
||||
data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
|
||||
data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data<int>();
|
||||
data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span<const int64_t>() : key_padding_mask->Shape().GetDims();
|
||||
data.past = nullptr;
|
||||
data.past_key = pass_key_value_as_past ? reinterpret_cast<const CudaT*>(key->Data<T>())
|
||||
: (nullptr == past_key) ? nullptr
|
||||
: reinterpret_cast<const CudaT*>(past_key->Data<T>());
|
||||
|
|
@ -283,7 +281,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast<CudaT*>(temp_k_work_space.get()) : nullptr;
|
||||
data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast<CudaT*>(temp_v_work_space.get()) : nullptr;
|
||||
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
|
||||
data.present = nullptr;
|
||||
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
|
||||
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
|
||||
data.fused_runner = reinterpret_cast<void*>(fused_runner);
|
||||
|
|
|
|||
|
|
@ -195,28 +195,21 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
AttentionData<CudaT> data;
|
||||
data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
|
||||
data.bias = nullptr; // bias has been added
|
||||
data.query = nullptr;
|
||||
data.key = nullptr;
|
||||
data.value = nullptr;
|
||||
data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data<int>();
|
||||
data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
|
||||
data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast<const CudaT*>(past_tensor->Data<T>());
|
||||
data.past_key = nullptr;
|
||||
data.past_value = nullptr;
|
||||
data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention
|
||||
if (nullptr != mask_index) {
|
||||
data.mask_index = mask_index->Data<int>();
|
||||
data.mask_index_dims = mask_index->Shape().GetDims();
|
||||
}
|
||||
|
||||
if (nullptr != past_tensor) {
|
||||
data.past = reinterpret_cast<const CudaT*>(past_tensor->Data<T>());
|
||||
}
|
||||
|
||||
data.has_qkv_workspace = true;
|
||||
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
|
||||
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
|
||||
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
|
||||
data.present_key = nullptr;
|
||||
data.present_value = nullptr;
|
||||
data.fused_runner = fused_runner;
|
||||
data.fused_cross_attention_kernel = nullptr;
|
||||
data.use_flash_attention = use_flash_attention;
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
data.cumulated_sequence_length_q_cache = nullptr;
|
||||
data.cumulated_sequence_length_kv_cache = nullptr;
|
||||
if (nullptr != present) {
|
||||
data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
|
||||
}
|
||||
|
||||
return QkvToContext<CudaT>(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue