Refactoring of attention cuda kernel: move prepare qkv and concat_past_to_present (#17559)

To avoid a huge cu file and make code more readable:
 - Move PrepareQKV to separate cu file (attention_prepare_qkv.cu)
 - Move ConcatPastToPresent to attention_concat.cu
 - Add default value for AttentionData
- Add a data structure QkvData to track Q, K and V pointers and track
QKV format.
This commit is contained in:
Tianlei Wu 2023-09-15 10:57:29 -07:00 committed by GitHub
parent af80542e65
commit adb0be45d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 729 additions and 674 deletions

View file

@ -10,6 +10,7 @@ set(contrib_ops_excluded_files
"bert/attention_impl.cu"
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/decoder_masked_multihead_attention.h"
"bert/decoder_masked_multihead_attention.cc"
"bert/decoder_masked_self_attention.h"

View file

@ -249,30 +249,28 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
data.bias = nullptr == bias ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
data.query = nullptr;
data.key = nullptr;
data.value = nullptr;
data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data<int>();
data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
data.past = (nullptr == past) ? nullptr : reinterpret_cast<const CudaT*>(past->Data<T>());
data.past_key = nullptr;
data.past_value = nullptr;
data.relative_position_bias = (nullptr == relative_position_bias)
? nullptr
: reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
if (nullptr != bias) {
data.bias = reinterpret_cast<const CudaT*>(bias->Data<T>());
}
if (nullptr != mask_index) {
data.mask_index = mask_index->Data<int>();
data.mask_index_dims = mask_index->Shape().GetDims();
}
if (nullptr != past) {
data.past = reinterpret_cast<const CudaT*>(past->Data<T>());
}
if (nullptr != relative_position_bias) {
data.relative_position_bias = reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
}
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
data.present_key = nullptr;
data.present_value = nullptr;
if (nullptr != present) {
data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
}
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = nullptr;
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
data.cumulated_sequence_length_q_cache = nullptr;
data.cumulated_sequence_length_kv_cache = nullptr;
return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
}

View file

@ -93,16 +93,16 @@ __global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length,
}
Status LaunchConcatTensorToTensor(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const float* tensor_in,
const float* tensor_add,
float* tensor_out) {
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const float* tensor_in,
const float* tensor_add,
float* tensor_out) {
const dim3 grid(all_sequence_length, batch_size, matrix_num);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
@ -137,16 +137,16 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
}
Status LaunchConcatTensorToTensor(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const half* tensor_in,
const half* tensor_add,
half* tensor_out) {
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const half* tensor_in,
const half* tensor_add,
half* tensor_out) {
const dim3 grid(all_sequence_length, batch_size, matrix_num);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
@ -197,15 +197,15 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
}
Status LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present) {
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present) {
return LaunchConcatTensorToTensor(
stream,
all_sequence_length,
@ -221,15 +221,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
}
Status LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present) {
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present) {
return LaunchConcatTensorToTensor(
stream,
all_sequence_length,
@ -244,6 +244,90 @@ Status LaunchConcatPastToPresent(cudaStream_t stream,
present);
}
#ifndef USE_ROCM // exclude from hipify
template <typename T>
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<T>& data,
QkvData<T>& qkv) {
// Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
// When there is past state, the head size for Q/K/V shall be same: H == H_v.
if (nullptr != data.present) {
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
ORT_RETURN_IF_ERROR(
LaunchConcatPastToPresent(
stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, data.past, qkv.k, data.present));
// Update pointers to present_k and present_v.
qkv.k = data.present;
qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size;
}
if (nullptr != data.past_key || nullptr != data.present_key) {
if (nullptr != data.past_key && nullptr == data.present_key) {
qkv.k = const_cast<T*>(data.past_key);
qkv.v = const_cast<T*>(data.past_value);
} else if (nullptr == data.past_key && nullptr != data.present_key) {
if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) {
qkv.k = data.present_key;
qkv.v = data.present_value;
} else {
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
qkv.k = data.temp_k_workspace;
qkv.v = data.temp_v_workspace;
}
} else if (pass_past_in_kv) {
// past_key and past_value are used directly as key and value in attention computations
qkv.k = const_cast<T*>(data.past_key);
qkv.v = const_cast<T*>(data.past_value);
// This path has a memory copy from past_key and past_value to present_key and present_value
// Avoid this path since the memory copy is unnecessary because past_key == present_key and
// past_value == present_value
int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size;
int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size;
cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
} else {
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
batch_size, qk_head_size, num_heads,
max_threads_per_block, 1, data.past_key, qkv.k, data.present_key));
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length,
batch_size, v_head_size, num_heads,
max_threads_per_block, 1, data.past_value, qkv.v, data.present_value));
// Update pointers to present_k and present_v.
qkv.k = data.present_key;
qkv.v = data.present_value;
}
}
return CUDA_CALL(cudaGetLastError());
}
// Template Instantiation
template Status ConcatPastToPresent<float>(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<float>& data,
QkvData<float>& qkv);
template Status ConcatPastToPresent<half>(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<half>& data,
QkvData<half>& qkv);
#endif
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -26,16 +26,11 @@ limitations under the License.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cassert>
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention_softmax.h"
#include "contrib_ops/cuda/bert/transformer_common.h"
#include "contrib_ops/cuda/bert/add_bias_transpose.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h"
#include "contrib_ops/cpu/bert/attention_base.h"
@ -43,6 +38,7 @@ limitations under the License.
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
using namespace onnxruntime::cuda;
using namespace onnxruntime::contrib::attention_softmax_cuda;
@ -286,446 +282,6 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const half* qkv_buffer,
half* present);
template <typename T>
Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
const bool past_present_share_buffer = parameters.past_present_share_buffer;
void* fused_runner = data.fused_runner;
bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
if (data.bias == nullptr) {
assert(nullptr == fused_runner);
// For quantized attention, bias has been added so only need transpose here.
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
assert(qk_head_size == v_head_size);
int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.gemm_buffer, qkv, 3));
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
} else {
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
// For fused causal kernel, use format 1 since we need have K and V to update present state,
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
qkv_format = use_fused_kernel
? AttentionQkvFormat::QKV_BSN3H
: (use_flash_or_efficient_attention
? AttentionQkvFormat::Q_K_V_BSNH
: (use_fused_causal
? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
: AttentionQkvFormat::Q_K_V_BNSH));
// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
// format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
// format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.past_sequence_length);
}
return Status::OK();
}
// For MultiHeadAttention with past state
template <typename T>
Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
DUMP_TENSOR_INIT();
if (data.bias == nullptr) {
// Below logic does not support fused attention with past without bias
// When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past.
// cross attention with past state
if (data.past_key != nullptr && data.present_key == nullptr) {
assert(data.past_value != nullptr);
assert(data.query != nullptr);
assert(data.key == nullptr);
assert(data.value == nullptr);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
}
// cross attention with present state or self attention with present state
else if (data.past_key == nullptr && data.present_key != nullptr) {
assert(data.past_value == nullptr);
assert(data.present_value != nullptr);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
// TODO: supporting packed qkv for self attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
// TODO: supporting packed kv for cross attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.present_key));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.present_value));
}
// self attention with past and present state
else {
assert(data.past_key != nullptr);
assert(data.past_value != nullptr);
assert(data.present_key != nullptr);
assert(data.present_value != nullptr);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
// TODO: supporting packed qkv for self attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, v));
}
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
// When past_key/past_value are inputted directly as key/value and there is no present_key/present_value
else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
data.past_key != nullptr &&
data.past_value != nullptr &&
parameters.pass_past_in_kv) {
// Transpose past_key and past_value to use memory efficient attention
// past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.past_key, data.temp_k_workspace));
// past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.past_value, data.temp_v_workspace));
// query => q, temp_k_workspace => k, temp_v_workspace => v
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
data.past_key = nullptr;
data.past_value = nullptr;
}
// When there is no past_key/past_value and there is present_key/present_value
// (e.g. get initial kv to use as past_kv in the next iteration)
else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
data.present_key != nullptr &&
data.present_value != nullptr) {
// Use memory efficient attention kernel
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
// temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.temp_k_workspace, data.present_key));
// temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.temp_v_workspace, data.present_value));
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
else {
// Use unfused kernel for Q, use unfused kernel for K and V if needed
constexpr int format = 0;
// Query (BxSxNxH) => Q (BxNxSxH)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, q,
true, -1);
if (!parameters.pass_past_in_kv) {
T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;
// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, data.bias + num_heads * qk_head_size, k_dest,
true, -1);
// Value (BxLxNxH_v) => V (BxNxLxH_v)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
true, -1);
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size);
DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size);
}
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
return Status::OK();
}
// For MultiHeadAttention without past state, with packed QKV inputs
template <typename T>
Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
void* fused_runner = data.fused_runner;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
assert(data.bias == nullptr);
assert(qk_head_size == v_head_size);
DUMP_TENSOR_INIT();
DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);
if (data.use_memory_efficient_attention || data.use_flash_attention) {
// unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
constexpr int format = 4;
T* qkv_add_bias = nullptr;
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, qkv,
true, v_head_size, qkv_add_bias, 3);
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else {
if (!use_fused_kernel) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, NOT_IMPLEMENTED,
"packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
}
qkv_format = AttentionQkvFormat::QKV_BSN3H;
}
return Status::OK();
}
// For MultiHeadAttention without past state, with packed KV inputs
template <typename T>
Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
// TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
// CheckInputs verified this constraint.
assert(data.bias == nullptr);
assert(qk_head_size == v_head_size);
DUMP_TENSOR_INIT();
DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);
if (data.use_memory_efficient_attention || data.use_flash_attention) {
// unpack kv to BSNH. Note that there is no bias so we need not output query to q.
constexpr int format = 4;
T* qkv_add_bias = nullptr;
const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, kv_bias, k,
true, v_head_size, qkv_add_bias, 2);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else {
if (data.fused_cross_attention_kernel == nullptr) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, NOT_IMPLEMENTED,
"packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
}
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
return Status::OK();
}
// For MultiHeadAttention without past state, with Q, K and V inputs
template <typename T>
Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
void* fused_runner = data.fused_runner;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
// gemm_buffer == nullptr and not packed
assert(data.query != nullptr && data.key != nullptr && data.value != nullptr);
DUMP_TENSOR_INIT();
DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size);
#if DUMP_TENSOR_LEVEL > 1
if (data.bias != nullptr) {
DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
}
#endif
if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
num_heads, sequence_length, kv_sequence_length);
}
if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
}
if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
// For fused cross attention, besides adding bias, K and V needed to be packed:
// K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
else if (data.use_memory_efficient_attention || data.use_flash_attention) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, q, k, v);
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
else if (use_fused_kernel) {
assert(qk_head_size == v_head_size);
// Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
qkv_format = AttentionQkvFormat::QKV_BSN3H;
} else { // unfused kernel
ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
// Query (BxSxNxH) => Q (BxNxSxH)
constexpr int format = 0;
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, q,
true, -1);
// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
true, -1);
// Value (BxLxNxH_v) => K (BxNxLxH_v)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
true, -1);
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
return Status::OK();
}
template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
if (nullptr != data.gemm_buffer) { // Attention operator
ORT_RETURN_IF_ERROR(PrepareQkv_Attention<T>(parameters, data, stream, max_threads_per_block, qkv_format));
} else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
} else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
} else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
} else { // multihead attention operator, no past, separated Q/K/V inputs
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
}
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
@ -755,92 +311,22 @@ Status QkvToContext(
const int batches = batch_size * num_heads;
T* qkv = nullptr;
T* q = nullptr;
T* k = nullptr;
T* v = nullptr;
T* scratch1 = data.workspace;
if (data.has_qkv_workspace) {
const int size_per_batch_q = sequence_length * qk_head_size;
const int size_per_batch_k = kv_sequence_length * qk_head_size;
const int size_per_batch_v = kv_sequence_length * v_head_size;
const size_t elements_q = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_q);
const size_t elements_k = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_k);
const size_t elements_v = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_v);
qkv = data.workspace;
q = qkv;
k = q + elements_q;
v = k + elements_k;
scratch1 = v + elements_v;
}
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
QkvData<T> qkv;
ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block, qkv));
T* scratch1 = data.has_qkv_workspace ? qkv.after_v : data.workspace;
int present_size_per_batch_k = 0;
int present_size_per_batch_v = 0;
if (!past_present_share_buffer) {
// Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length.
// past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH)
// past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH)
// When there is past state, the head size for Q/K/V shall be same: H == H_v.
present_size_per_batch_k = total_sequence_length * qk_head_size;
present_size_per_batch_v = total_sequence_length * v_head_size;
ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size,
sequence_length, total_sequence_length, parameters.pass_past_in_kv,
stream, max_threads_per_block, data, qkv));
if (nullptr != data.present) {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
ORT_RETURN_IF_ERROR(
LaunchConcatPastToPresent(
stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, data.past, k, data.present));
// Update pointers to present_k and present_v.
k = data.present;
v = data.present + batches * present_size_per_batch_k;
}
if (nullptr != data.past_key || nullptr != data.present_key) {
if (nullptr != data.past_key && nullptr == data.present_key) {
k = const_cast<T*>(data.past_key);
v = const_cast<T*>(data.past_value);
} else if (nullptr == data.past_key && nullptr != data.present_key) {
if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
k = data.present_key;
v = data.present_value;
} else {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
k = data.temp_k_workspace;
v = data.temp_v_workspace;
}
} else if (parameters.pass_past_in_kv) {
// past_key and past_value are used directly as key and value in attention computations
k = const_cast<T*>(data.past_key);
v = const_cast<T*>(data.past_value);
// This path has a memory copy from past_key and past_value to present_key and present_value
// Avoid this path since the memory copy is unnecessary because past_key == present_key and
// past_value == present_value
int64_t k_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * qk_head_size;
int64_t v_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * v_head_size;
cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
} else {
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
batch_size, qk_head_size, num_heads,
max_threads_per_block, 1, data.past_key, k, data.present_key));
ORT_RETURN_IF_ERROR(
LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
batch_size, v_head_size, num_heads,
max_threads_per_block, 1, data.past_value, v, data.present_value));
// Update pointers to present_k and present_v.
k = data.present_key;
v = data.present_value;
}
}
} else { // past_present_share_buffer
assert(qk_head_size == v_head_size);
assert(data.fused_cross_attention_kernel == nullptr);
@ -870,15 +356,15 @@ Status QkvToContext(
present_size_per_batch_k = parameters.max_sequence_length * qk_head_size;
present_size_per_batch_v = present_size_per_batch_k;
k = data.present;
v = data.present + batches * present_size_per_batch_k;
qkv.k = data.present;
qkv.v = data.present + batches * present_size_per_batch_k;
}
// Q, K and V are ready now
DUMP_TENSOR_INIT();
if (data.fused_cross_attention_kernel != nullptr) {
assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
assert(qkv.format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
@ -902,8 +388,8 @@ Status QkvToContext(
reinterpret_cast<FusedMultiHeadCrossAttentionKernel const*>(data.fused_cross_attention_kernel);
// When there is no bias, we can directly use q and packed kv from inputs.
void const* query = q;
void const* packed_kv = k;
void const* query = qkv.q;
void const* packed_kv = qkv.k;
if (data.value == nullptr && data.bias == nullptr) {
query = data.query;
packed_kv = data.key;
@ -951,10 +437,10 @@ Status QkvToContext(
fused_fp16_runner->setup(S, B);
if (use_fused_kernel) {
assert(qkv_format == AttentionQkvFormat::QKV_BSN3H);
assert(qkv.format == AttentionQkvFormat::QKV_BSN3H);
// When there is no bias, we can directly use packed qkv from inputs.
void const* packed_qkv = qkv;
void const* packed_qkv = qkv.q;
if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) {
packed_qkv = data.query;
}
@ -962,7 +448,7 @@ Status QkvToContext(
fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
} else {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size);
}
@ -975,22 +461,22 @@ Status QkvToContext(
#if USE_FLASH_ATTENTION
if (data.use_flash_attention) {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
assert(nullptr == data.mask_index);
assert(nullptr == data.relative_position_bias);
assert(parameters.head_size == parameters.v_head_size);
void* query = reinterpret_cast<void*>(q);
void* key = reinterpret_cast<void*>(k);
void* value = reinterpret_cast<void*>(v);
void* query = reinterpret_cast<void*>(qkv.q);
void* key = reinterpret_cast<void*>(qkv.k);
void* value = reinterpret_cast<void*>(qkv.v);
// For packed KV, we can use query input directly.
if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) {
query = reinterpret_cast<void*>(const_cast<T*>(data.query));
}
DUMP_TENSOR_D("q(BSNH)", reinterpret_cast<const T*>(query), batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
constexpr bool is_causal = false;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
@ -1008,11 +494,11 @@ Status QkvToContext(
if (data.use_memory_efficient_attention) {
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH);
const void* query = q;
const void* key = k;
const void* value = v;
const void* query = qkv.q;
const void* key = qkv.k;
const void* value = qkv.v;
// For packed KV, we can use query input directly.
if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) {
assert(data.bias == nullptr);
@ -1020,8 +506,8 @@ Status QkvToContext(
}
DUMP_TENSOR_D("q(BSNH)", reinterpret_cast<const T*>(query), batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
@ -1061,7 +547,7 @@ Status QkvToContext(
#endif
// The following are unfused attention.
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH);
const int* mask_index = data.mask_index;
gsl::span<const int64_t>& mask_index_dims = data.mask_index_dims;
@ -1082,12 +568,12 @@ Status QkvToContext(
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
total_sequence_length, sequence_length, qk_head_size,
&alpha, k, qk_head_size, present_size_per_batch_k,
q, qk_head_size, sequence_length * qk_head_size,
&alpha, qkv.k, qk_head_size, present_size_per_batch_k,
qkv.q, qk_head_size, sequence_length * qk_head_size,
&zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length);
DUMP_TENSOR_D("Q", qkv.q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("K", qkv.k, batch_size, num_heads, qk_head_size, sequence_length);
DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length);
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
@ -1126,14 +612,14 @@ Status QkvToContext(
}
DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size);
DUMP_TENSOR_D("V", qkv.v, batch_size, num_heads, sequence_length, v_head_size);
// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = qkv;
T* temp_output = qkv.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
&one, v, v_head_size, present_size_per_batch_v,
&one, qkv.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));

View file

@ -2,11 +2,12 @@
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/common/gsl.h"
#include "core/framework/allocator.h"
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
@ -49,39 +50,56 @@ size_t GetAttentionWorkspaceSize(
template <typename T>
struct AttentionData {
T* gemm_buffer;
const T* bias;
T* gemm_buffer = nullptr;
const T* bias = nullptr;
const T* query;
const T* key;
const T* value;
const int* mask_index;
const T* query = nullptr;
const T* key = nullptr;
const T* value = nullptr;
const int* mask_index = nullptr;
gsl::span<const int64_t> mask_index_dims;
const T* past;
const T* past_key;
const T* past_value;
const T* relative_position_bias;
const T* past = nullptr;
const T* past_key = nullptr;
const T* past_value = nullptr;
const T* relative_position_bias = nullptr;
bool has_qkv_workspace;
T* workspace;
T* temp_k_workspace;
T* temp_v_workspace;
bool has_qkv_workspace = false;
T* workspace = nullptr;
T* temp_k_workspace = nullptr;
T* temp_v_workspace = nullptr;
T* output;
T* present;
T* present_key;
T* present_value;
T* output = nullptr;
T* present = nullptr;
T* present_key = nullptr;
T* present_value = nullptr;
void* fused_runner;
const void* fused_cross_attention_kernel;
void* fused_runner = nullptr;
const void* fused_cross_attention_kernel = nullptr;
bool use_flash_attention;
bool use_memory_efficient_attention;
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache;
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache;
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr;
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr;
};
// Intermediate data pointers available after PrepareQKV
template <typename T>
struct QkvData {
T* q = nullptr;
T* k = nullptr;
T* v = nullptr;
T* after_v = nullptr; // pointer right after v
AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH;
};
template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<T>& qkv_data);
template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
@ -161,27 +179,13 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
const half* tensor_add,
half* tensor_out);
Status LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present);
Status LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present);
template <typename T>
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length, bool pass_past_in_kv,
cudaStream_t stream,
int max_threads_per_block,
AttentionData<T>& data,
QkvData<T>& qkv);
template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,

View file

@ -0,0 +1,492 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cuda_fp16.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/add_bias_transpose.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
using namespace onnxruntime::cuda;
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T>
Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
const bool past_present_share_buffer = parameters.past_present_share_buffer;
void* fused_runner = data.fused_runner;
bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
if (data.bias == nullptr) {
assert(nullptr == fused_runner);
// For quantized attention, bias has been added so only need transpose here.
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
assert(qk_head_size == v_head_size);
int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.gemm_buffer, qkv, 3));
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
} else {
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
// For fused causal kernel, use format 1 since we need have K and V to update present state,
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
qkv_format = use_fused_kernel
? AttentionQkvFormat::QKV_BSN3H
: (use_flash_or_efficient_attention
? AttentionQkvFormat::Q_K_V_BSNH
: (use_fused_causal
? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
: AttentionQkvFormat::Q_K_V_BNSH));
// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
// format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
// format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.past_sequence_length);
}
return Status::OK();
}
// For MultiHeadAttention with past state
template <typename T>
Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
DUMP_TENSOR_INIT();
if (data.bias == nullptr) {
// Below logic does not support fused attention with past without bias
// When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past.
// cross attention with past state
if (data.past_key != nullptr && data.present_key == nullptr) {
assert(data.past_value != nullptr);
assert(data.query != nullptr);
assert(data.key == nullptr);
assert(data.value == nullptr);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
}
// cross attention with present state or self attention with present state
else if (data.past_key == nullptr && data.present_key != nullptr) {
assert(data.past_value == nullptr);
assert(data.present_value != nullptr);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
// TODO: supporting packed qkv for self attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
// TODO: supporting packed kv for cross attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.present_key));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.present_value));
}
// self attention with past and present state
else {
assert(data.past_key != nullptr);
assert(data.past_value != nullptr);
assert(data.present_key != nullptr);
assert(data.present_value != nullptr);
assert(data.query != nullptr);
assert(data.key != nullptr);
assert(data.value != nullptr);
// TODO: supporting packed qkv for self attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, v));
}
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
// When past_key/past_value are inputted directly as key/value and there is no present_key/present_value
else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
data.past_key != nullptr &&
data.past_value != nullptr &&
parameters.pass_past_in_kv) {
// Transpose past_key and past_value to use memory efficient attention
// past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.past_key, data.temp_k_workspace));
// past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.past_value, data.temp_v_workspace));
// query => q, temp_k_workspace => k, temp_v_workspace => v
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
data.past_key = nullptr;
data.past_value = nullptr;
}
// When there is no past_key/past_value and there is present_key/present_value
// (e.g. get initial kv to use as past_kv in the next iteration)
else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
data.present_key != nullptr &&
data.present_value != nullptr) {
// Use memory efficient attention kernel
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
// temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.temp_k_workspace, data.present_key));
// temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.temp_v_workspace, data.present_value));
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
else {
// Use unfused kernel for Q, use unfused kernel for K and V if needed
constexpr int format = 0;
// Query (BxSxNxH) => Q (BxNxSxH)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, q,
true, -1);
if (!parameters.pass_past_in_kv) {
T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;
// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, data.bias + num_heads * qk_head_size, k_dest,
true, -1);
// Value (BxLxNxH_v) => V (BxNxLxH_v)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
true, -1);
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size);
DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size);
}
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
return Status::OK();
}
// For MultiHeadAttention without past state, with packed QKV inputs
template <typename T>
Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
void* fused_runner = data.fused_runner;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
assert(data.bias == nullptr);
assert(qk_head_size == v_head_size);
DUMP_TENSOR_INIT();
DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);
if (data.use_memory_efficient_attention || data.use_flash_attention) {
// unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
constexpr int format = 4;
T* qkv_add_bias = nullptr;
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, qkv,
true, v_head_size, qkv_add_bias, 3);
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else {
if (!use_fused_kernel) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, NOT_IMPLEMENTED,
"packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
}
qkv_format = AttentionQkvFormat::QKV_BSN3H;
}
return Status::OK();
}
// For MultiHeadAttention without past state, with packed KV inputs
template <typename T>
Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
// TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
// CheckInputs verified this constraint.
assert(data.bias == nullptr);
assert(qk_head_size == v_head_size);
DUMP_TENSOR_INIT();
DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);
if (data.use_memory_efficient_attention || data.use_flash_attention) {
// unpack kv to BSNH. Note that there is no bias so we need not output query to q.
constexpr int format = 4;
T* qkv_add_bias = nullptr;
const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, kv_bias, k,
true, v_head_size, qkv_add_bias, 2);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else {
if (data.fused_cross_attention_kernel == nullptr) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, NOT_IMPLEMENTED,
"packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
}
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
return Status::OK();
}
// For MultiHeadAttention without past state, with Q, K and V inputs
template <typename T>
Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
void* fused_runner = data.fused_runner;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
// gemm_buffer == nullptr and not packed
assert(data.query != nullptr && data.key != nullptr && data.value != nullptr);
DUMP_TENSOR_INIT();
DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size);
#if DUMP_TENSOR_LEVEL > 1
if (data.bias != nullptr) {
DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
}
#endif
if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
num_heads, sequence_length, kv_sequence_length);
}
if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
}
if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
// For fused cross attention, besides adding bias, K and V needed to be packed:
// K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
else if (data.use_memory_efficient_attention || data.use_flash_attention) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, q, k, v);
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
else if (use_fused_kernel) {
assert(qk_head_size == v_head_size);
// Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
LaunchAddBiasTransposeTrt(
stream, max_threads_per_block,
batch_size, sequence_length,
num_heads, qk_head_size,
data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
qkv_format = AttentionQkvFormat::QKV_BSN3H;
} else { // unfused kernel
ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
// Query (BxSxNxH) => Q (BxNxSxH)
constexpr int format = 0;
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, q,
true, -1);
// Key (BxLxNxH) => K (BxNxLxH)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
true, -1);
// Value (BxLxNxH_v) => K (BxNxLxH_v)
LaunchAddBiasTranspose<T>(stream, 1, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, v_head_size,
data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
true, -1);
DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
return Status::OK();
}
template <typename T>
Status PrepareQkv(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<T>& qkv) {
if (data.has_qkv_workspace) {
const int size_per_batch_q = parameters.sequence_length * parameters.head_size;
const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size;
const int size_per_batch_v = parameters.kv_sequence_length * parameters.v_head_size;
const int batches = parameters.batch_size * parameters.num_heads;
const size_t elements_q = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_q);
const size_t elements_k = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_k);
const size_t elements_v = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_v);
qkv.q = data.workspace;
qkv.k = data.workspace + elements_q;
qkv.v = qkv.k + elements_k;
qkv.after_v = qkv.v + elements_v;
}
if (nullptr != data.gemm_buffer) { // Attention operator
ORT_RETURN_IF_ERROR(PrepareQkv_Attention<T>(parameters, data, stream, max_threads_per_block,
qkv.format));
} else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
} else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
} else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
} else { // multihead attention operator, no past, separated Q/K/V inputs
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block,
qkv.q, qkv.k, qkv.v, qkv.format));
}
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
// Template Instantiation
template Status PrepareQkv<float>(
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<float>& qkv);
template Status PrepareQkv<half>(
contrib::AttentionParameters& parameters,
AttentionData<half>& data,
cudaStream_t stream,
int max_threads_per_block,
QkvData<half>& qkv);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -263,14 +263,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
data.gemm_buffer = nullptr;
data.bias = (nullptr == bias) ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data<int>();
data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span<const int64_t>() : key_padding_mask->Shape().GetDims();
data.past = nullptr;
data.past_key = pass_key_value_as_past ? reinterpret_cast<const CudaT*>(key->Data<T>())
: (nullptr == past_key) ? nullptr
: reinterpret_cast<const CudaT*>(past_key->Data<T>());
@ -283,7 +281,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast<CudaT*>(temp_k_work_space.get()) : nullptr;
data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast<CudaT*>(temp_v_work_space.get()) : nullptr;
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present = nullptr;
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.fused_runner = reinterpret_cast<void*>(fused_runner);

View file

@ -195,28 +195,21 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
data.bias = nullptr; // bias has been added
data.query = nullptr;
data.key = nullptr;
data.value = nullptr;
data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data<int>();
data.mask_index_dims = (nullptr == mask_index) ? gsl::span<const int64_t>() : mask_index->Shape().GetDims();
data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast<const CudaT*>(past_tensor->Data<T>());
data.past_key = nullptr;
data.past_value = nullptr;
data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention
if (nullptr != mask_index) {
data.mask_index = mask_index->Data<int>();
data.mask_index_dims = mask_index->Shape().GetDims();
}
if (nullptr != past_tensor) {
data.past = reinterpret_cast<const CudaT*>(past_tensor->Data<T>());
}
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
data.present_key = nullptr;
data.present_value = nullptr;
data.fused_runner = fused_runner;
data.fused_cross_attention_kernel = nullptr;
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
data.cumulated_sequence_length_q_cache = nullptr;
data.cumulated_sequence_length_kv_cache = nullptr;
if (nullptr != present) {
data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
}
return QkvToContext<CudaT>(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data);
}