diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 25202f82f4..b483e6de81 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -10,6 +10,7 @@ set(contrib_ops_excluded_files "bert/attention_impl.cu" "bert/attention_softmax.h" "bert/attention_softmax.cu" + "bert/attention_prepare_qkv.cu" "bert/decoder_masked_multihead_attention.h" "bert/decoder_masked_multihead_attention.cc" "bert/decoder_masked_self_attention.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index a79ad96b94..f0385ea5ab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -249,30 +249,28 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr == bias ? nullptr : reinterpret_cast(bias->Data()); - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = (nullptr == relative_position_bias) - ? nullptr - : reinterpret_cast(relative_position_bias->Data()); + if (nullptr != bias) { + data.bias = reinterpret_cast(bias->Data()); + } + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + if (nullptr != past) { + data.past = reinterpret_cast(past->Data()); + } + if (nullptr != relative_position_bias) { + data.relative_position_bias = reinterpret_cast(relative_position_bias->Data()); + } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); - data.present_key = nullptr; - data.present_value = nullptr; + if (nullptr != present) { + data.present = reinterpret_cast(present->MutableData()); + } data.fused_runner = reinterpret_cast(fused_runner); - data.fused_cross_attention_kernel = nullptr; data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = nullptr; - data.cumulated_sequence_length_kv_cache = nullptr; return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu index 5d9cfcc697..8378ee2691 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu @@ -93,16 +93,16 @@ __global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, } Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const float* tensor_in, + const float* tensor_add, + float* tensor_out) { const dim3 grid(all_sequence_length, batch_size, matrix_num); if (0 == (head_size & 1)) { const int H = head_size / 2; @@ -137,16 +137,16 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, } Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const half* tensor_in, + const half* tensor_add, + half* tensor_out) { const dim3 grid(all_sequence_length, batch_size, matrix_num); if (0 == (head_size % 4)) { const int H = head_size / 4; @@ -197,15 +197,15 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, } Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* past, + const float* k_v, + float* present) { return LaunchConcatTensorToTensor( stream, all_sequence_length, @@ -221,15 +221,15 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, } Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* past, + const half* k_v, + half* present) { return LaunchConcatTensorToTensor( stream, all_sequence_length, @@ -244,6 +244,90 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, present); } +#ifndef USE_ROCM // exclude from hipify +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv) { + // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. + // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) + // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) + // When there is past state, the head size for Q/K/V shall be same: H == H_v. + + if (nullptr != data.present) { + assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + ORT_RETURN_IF_ERROR( + LaunchConcatPastToPresent( + stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, data.past, qkv.k, data.present)); + + // Update pointers to present_k and present_v. + qkv.k = data.present; + qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; + } + + if (nullptr != data.past_key || nullptr != data.present_key) { + if (nullptr != data.past_key && nullptr == data.present_key) { + qkv.k = const_cast(data.past_key); + qkv.v = const_cast(data.past_value); + } else if (nullptr == data.past_key && nullptr != data.present_key) { + if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) { + qkv.k = data.present_key; + qkv.v = data.present_value; + } else { + assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); + qkv.k = data.temp_k_workspace; + qkv.v = data.temp_v_workspace; + } + } else if (pass_past_in_kv) { + // past_key and past_value are used directly as key and value in attention computations + qkv.k = const_cast(data.past_key); + qkv.v = const_cast(data.past_value); + + // This path has a memory copy from past_key and past_value to present_key and present_value + // Avoid this path since the memory copy is unnecessary because past_key == present_key and + // past_value == present_value + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } else { + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, qkv.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, qkv.v, data.present_value)); + // Update pointers to present_k and present_v. + qkv.k = data.present_key; + qkv.v = data.present_value; + } + } + + return CUDA_CALL(cudaGetLastError()); +} + +// Template Instantiation +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv); + +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv); +#endif + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index ae7696eb9f..366d8fee14 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -26,16 +26,11 @@ limitations under the License. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_softmax.h" #include "contrib_ops/cuda/bert/transformer_common.h" -#include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" #include "contrib_ops/cpu/bert/attention_base.h" @@ -43,6 +38,7 @@ limitations under the License. #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/attention_impl.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -286,446 +282,6 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const half* qkv_buffer, half* present); -template -Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - void* fused_runner = data.fused_runner; - bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - if (data.bias == nullptr) { - assert(nullptr == fused_runner); - // For quantized attention, bias has been added so only need transpose here. - // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH - assert(qk_head_size == v_head_size); - int matrix_to_trans = (past_present_share_buffer ? 1 : 3); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } else { - // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) - // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) - // For unfused kernel, transpose to 3xBxNxSxH (format 1) - // For fused causal kernel, use format 1 since we need have K and V to update present state, - // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. - const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_flash_or_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal - ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH - : AttentionQkvFormat::Q_K_V_BNSH)); - - // For fused causal, we will update gemm_buffer with bias directly. - T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; - - int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); - // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v - // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) - LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.past_sequence_length); - } - return Status::OK(); -} - -// For MultiHeadAttention with past state -template -Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - DUMP_TENSOR_INIT(); - - if (data.bias == nullptr) { - // Below logic does not support fused attention with past without bias - // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. - - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.past_key != nullptr && - data.past_value != nullptr && - parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; - } - // When there is no past_key/past_value and there is present_key/present_value - // (e.g. get initial kv to use as past_kv in the next iteration) - else if ((data.use_memory_efficient_attention || data.use_flash_attention) && - data.present_key != nullptr && - data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); - - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else { - // Use unfused kernel for Q, use unfused kernel for K and V if needed - constexpr int format = 0; - // Query (BxSxNxH) => Q (BxNxSxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - if (!parameters.pass_past_in_kv) { - T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; - T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); - - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } - return Status::OK(); -} - -// For MultiHeadAttention without past state, with packed QKV inputs -template -Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); - - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, - true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); - } - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } - return Status::OK(); -} - -// For MultiHeadAttention without past state, with packed KV inputs -template -Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - - if (data.use_memory_efficient_attention || data.use_flash_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); - LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, kv_bias, k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, NOT_IMPLEMENTED, - "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } - return Status::OK(); -} - -// For MultiHeadAttention without past state, with Q, K and V inputs -template -Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - const int batch_size = parameters.batch_size; - const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - void* fused_runner = data.fused_runner; - - T* qkv = data.workspace; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } -#endif - - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, - num_heads, sequence_length, kv_sequence_length); - } - - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); - - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION - else if (data.use_memory_efficient_attention || data.use_flash_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); - - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } - return Status::OK(); -} - -template -Status PrepareQkv(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { - if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, qkv_format)); - } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } else { // multihead attention operator, no past, separated Q/K/V inputs - ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - } - - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - return Status::OK(); -} - template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -755,92 +311,22 @@ Status QkvToContext( const int batches = batch_size * num_heads; - T* qkv = nullptr; - T* q = nullptr; - T* k = nullptr; - T* v = nullptr; - T* scratch1 = data.workspace; - if (data.has_qkv_workspace) { - const int size_per_batch_q = sequence_length * qk_head_size; - const int size_per_batch_k = kv_sequence_length * qk_head_size; - const int size_per_batch_v = kv_sequence_length * v_head_size; - const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); - const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); - const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - qkv = data.workspace; - q = qkv; - k = q + elements_q; - v = k + elements_k; - scratch1 = v + elements_v; - } - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); + QkvData qkv; + ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, qkv)); + T* scratch1 = data.has_qkv_workspace ? qkv.after_v : data.workspace; int present_size_per_batch_k = 0; int present_size_per_batch_v = 0; if (!past_present_share_buffer) { - // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. - // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) - // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) - // When there is past state, the head size for Q/K/V shall be same: H == H_v. present_size_per_batch_k = total_sequence_length * qk_head_size; present_size_per_batch_v = total_sequence_length * v_head_size; + ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, + sequence_length, total_sequence_length, parameters.pass_past_in_kv, + stream, max_threads_per_block, data, qkv)); - if (nullptr != data.present) { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent( - stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, k, data.present)); - - // Update pointers to present_k and present_v. - k = data.present; - v = data.present + batches * present_size_per_batch_k; - } - - if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - k = const_cast(data.past_key); - v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { - k = data.present_key; - v = data.present_value; - } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - k = data.temp_k_workspace; - v = data.temp_v_workspace; - } - } else if (parameters.pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - k = const_cast(data.past_key); - v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, - batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, v, data.present_value)); - // Update pointers to present_k and present_v. - k = data.present_key; - v = data.present_value; - } - } } else { // past_present_share_buffer assert(qk_head_size == v_head_size); assert(data.fused_cross_attention_kernel == nullptr); @@ -870,15 +356,15 @@ Status QkvToContext( present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; present_size_per_batch_v = present_size_per_batch_k; - k = data.present; - v = data.present + batches * present_size_per_batch_k; + qkv.k = data.present; + qkv.v = data.present + batches * present_size_per_batch_k; } // Q, K and V are ready now DUMP_TENSOR_INIT(); if (data.fused_cross_attention_kernel != nullptr) { - assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); + assert(qkv.format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. @@ -902,8 +388,8 @@ Status QkvToContext( reinterpret_cast(data.fused_cross_attention_kernel); // When there is no bias, we can directly use q and packed kv from inputs. - void const* query = q; - void const* packed_kv = k; + void const* query = qkv.q; + void const* packed_kv = qkv.k; if (data.value == nullptr && data.bias == nullptr) { query = data.query; packed_kv = data.key; @@ -951,10 +437,10 @@ Status QkvToContext( fused_fp16_runner->setup(S, B); if (use_fused_kernel) { - assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); + assert(qkv.format == AttentionQkvFormat::QKV_BSN3H); // When there is no bias, we can directly use packed qkv from inputs. - void const* packed_qkv = qkv; + void const* packed_qkv = qkv.q; if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { packed_qkv = data.query; } @@ -962,7 +448,7 @@ Status QkvToContext( fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size); } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size); } @@ -975,22 +461,22 @@ Status QkvToContext( #if USE_FLASH_ATTENTION if (data.use_flash_attention) { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); assert(nullptr == data.mask_index); assert(nullptr == data.relative_position_bias); assert(parameters.head_size == parameters.v_head_size); - void* query = reinterpret_cast(q); - void* key = reinterpret_cast(k); - void* value = reinterpret_cast(v); + void* query = reinterpret_cast(qkv.q); + void* key = reinterpret_cast(qkv.k); + void* value = reinterpret_cast(qkv.v); // For packed KV, we can use query input directly. if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { query = reinterpret_cast(const_cast(data.query)); } DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); constexpr bool is_causal = false; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( @@ -1008,11 +494,11 @@ Status QkvToContext( if (data.use_memory_efficient_attention) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - const void* query = q; - const void* key = k; - const void* value = v; + const void* query = qkv.q; + const void* key = qkv.k; + const void* value = qkv.v; // For packed KV, we can use query input directly. if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { assert(data.bias == nullptr); @@ -1020,8 +506,8 @@ Status QkvToContext( } DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); + DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -1061,7 +547,7 @@ Status QkvToContext( #endif // The following are unfused attention. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH); const int* mask_index = data.mask_index; gsl::span& mask_index_dims = data.mask_index_dims; @@ -1082,12 +568,12 @@ Status QkvToContext( CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, - &alpha, k, qk_head_size, present_size_per_batch_k, - q, qk_head_size, sequence_length * qk_head_size, + &alpha, qkv.k, qk_head_size, present_size_per_batch_k, + qkv.q, qk_head_size, sequence_length * qk_head_size, &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("Q", qkv.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", qkv.k, batch_size, num_heads, qk_head_size, sequence_length); DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, @@ -1126,14 +612,14 @@ Status QkvToContext( } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("V", qkv.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = qkv; + T* temp_output = qkv.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, - &one, v, v_head_size, present_size_per_batch_v, + &one, qkv.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index af7373dd9f..c361a47c36 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -2,11 +2,12 @@ // Licensed under the MIT License. #pragma once -#include "core/providers/cuda/shared_inc/cuda_utils.h" + #include #include -#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/common/gsl.h" #include "core/framework/allocator.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { @@ -49,39 +50,56 @@ size_t GetAttentionWorkspaceSize( template struct AttentionData { - T* gemm_buffer; - const T* bias; + T* gemm_buffer = nullptr; + const T* bias = nullptr; - const T* query; - const T* key; - const T* value; - const int* mask_index; + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const int* mask_index = nullptr; gsl::span mask_index_dims; - const T* past; - const T* past_key; - const T* past_value; - const T* relative_position_bias; + const T* past = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + const T* relative_position_bias = nullptr; - bool has_qkv_workspace; - T* workspace; - T* temp_k_workspace; - T* temp_v_workspace; + bool has_qkv_workspace = false; + T* workspace = nullptr; + T* temp_k_workspace = nullptr; + T* temp_v_workspace = nullptr; - T* output; - T* present; - T* present_key; - T* present_value; + T* output = nullptr; + T* present = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; - void* fused_runner; - const void* fused_cross_attention_kernel; + void* fused_runner = nullptr; + const void* fused_cross_attention_kernel = nullptr; - bool use_flash_attention; - bool use_memory_efficient_attention; + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; }; +// Intermediate data pointers available after PrepareQKV +template +struct QkvData { + T* q = nullptr; + T* k = nullptr; + T* v = nullptr; + T* after_v = nullptr; // pointer right after v + AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH; +}; + +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv_data); + template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -161,27 +179,13 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, const half* tensor_add, half* tensor_out); -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present); - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present); +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data, + QkvData& qkv); template Status LaunchStridedCopy(cudaStream_t stream, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu new file mode 100644 index 0000000000..cd4137ab11 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -0,0 +1,492 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + const bool past_present_share_buffer = parameters.past_present_share_buffer; + void* fused_runner = data.fused_runner; + bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + if (data.bias == nullptr) { + assert(nullptr == fused_runner); + // For quantized attention, bias has been added so only need transpose here. + // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH + assert(qk_head_size == v_head_size); + int matrix_to_trans = (past_present_share_buffer ? 1 : 3); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.gemm_buffer, qkv, 3)); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } else { + // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) + // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) + // For unfused kernel, transpose to 3xBxNxSxH (format 1) + // For fused causal kernel, use format 1 since we need have K and V to update present state, + // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. + const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); + qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); + + // For fused causal, we will update gemm_buffer with bias directly. + T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; + + int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); + // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v + // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) + LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, + 3, parameters.do_rotary, parameters.past_sequence_length); + } + return Status::OK(); +} + +// For MultiHeadAttention with past state +template +Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + + if (data.bias == nullptr) { + // Below logic does not support fused attention with past without bias + // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. + + // cross attention with past state + if (data.past_key != nullptr && data.present_key == nullptr) { + assert(data.past_value != nullptr); + assert(data.query != nullptr); + assert(data.key == nullptr); + assert(data.value == nullptr); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + } + // cross attention with present state or self attention with present state + else if (data.past_key == nullptr && data.present_key != nullptr) { + assert(data.past_value == nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + + // TODO: supporting packed kv for cross attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.present_key)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.present_value)); + } + // self attention with past and present state + else { + assert(data.past_key != nullptr); + assert(data.past_value != nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, v)); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.past_key != nullptr && + data.past_value != nullptr && + parameters.pass_past_in_kv) { + // Transpose past_key and past_value to use memory efficient attention + + // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_key, data.temp_k_workspace)); + // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_value, data.temp_v_workspace)); + + // query => q, temp_k_workspace => k, temp_v_workspace => v + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + data.past_key = nullptr; + data.past_value = nullptr; + } + // When there is no past_key/past_value and there is present_key/present_value + // (e.g. get initial kv to use as past_kv in the next iteration) + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.present_key != nullptr && + data.present_value != nullptr) { + // Use memory efficient attention kernel + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); + + // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.temp_k_workspace, data.present_key)); + + // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } +#endif + else { + // Use unfused kernel for Q, use unfused kernel for K and V if needed + constexpr int format = 0; + // Query (BxSxNxH) => Q (BxNxSxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + if (!parameters.pass_past_in_kv) { + T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; + T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, k_dest, + true, -1); + + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed QKV inputs +template +Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, qkv, + true, v_head_size, qkv_add_bias, 3); + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (!use_fused_kernel) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed KV inputs +template +Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + // gemm_buffer == nullptr and not packed + assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + +#if DUMP_TENSOR_LEVEL > 1 + if (data.bias != nullptr) { + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } +#endif + + if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); + } + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + else if (data.use_memory_efficient_attention || data.use_flash_attention) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } +#endif + else if (use_fused_kernel) { + assert(qk_head_size == v_head_size); + + // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); + DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); + + // Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, + true, -1); + + // Value (BxLxNxH_v) => K (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv) { + if (data.has_qkv_workspace) { + const int size_per_batch_q = parameters.sequence_length * parameters.head_size; + const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; + const int size_per_batch_v = parameters.kv_sequence_length * parameters.v_head_size; + const int batches = parameters.batch_size * parameters.num_heads; + const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); + const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); + const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); + qkv.q = data.workspace; + qkv.k = data.workspace + elements_q; + qkv.v = qkv.k + elements_k; + qkv.after_v = qkv.v + elements_v; + } + + if (nullptr != data.gemm_buffer) { // Attention operator + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, + qkv.format)); + } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } else { // multihead attention operator, no past, separated Q/K/V inputs + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, + qkv.q, qkv.k, qkv.v, qkv.format)); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); +} + +// Template Instantiation +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv); + +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + QkvData& qkv); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 8f1252f863..25f3f59165 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -263,14 +263,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; - data.gemm_buffer = nullptr; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past = nullptr; data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) : (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); @@ -283,7 +281,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); - data.present = nullptr; data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index b0556512de..705f2d49fe 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -195,28 +195,21 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr; // bias has been added - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + + if (nullptr != past_tensor) { + data.past = reinterpret_cast(past_tensor->Data()); + } + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); - data.present_key = nullptr; - data.present_value = nullptr; - data.fused_runner = fused_runner; - data.fused_cross_attention_kernel = nullptr; - data.use_flash_attention = use_flash_attention; - data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = nullptr; - data.cumulated_sequence_length_kv_cache = nullptr; + if (nullptr != present) { + data.present = reinterpret_cast(present->MutableData()); + } return QkvToContext(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data); }