From 94c69f55d480cb4a8dcbc161d29ef3acca9392a7 Mon Sep 17 00:00:00 2001
From: aciddelgado <139922440+aciddelgado@users.noreply.github.com>
Date: Mon, 22 Apr 2024 19:57:05 -0700
Subject: [PATCH] GQA 4 CPU (#20299)
### Description
Support GQA operator on CPU with FP32.
### Motivation and Context
Right now, models generated for CPU and GPU must be different. GQA CPU
allows these models to be the same.
---
docs/ContribOperators.md | 5 +-
docs/OperatorKernels.md | 1 +
.../contrib_ops/cpu/bert/attention_base.h | 1 +
.../contrib_ops/cpu/bert/attention_helper.h | 70 +
.../contrib_ops/cpu/bert/attention_utils.cc | 246 +++
.../contrib_ops/cpu/bert/attention_utils.h | 61 +
.../contrib_ops/cpu/bert/gqa_attention_base.h | 276 +++
.../cpu/bert/group_query_attention.cc | 192 ++
.../cpu/bert/group_query_attention.h | 21 +
.../cpu/bert/group_query_attention_helper.h | 299 +++
.../cpu/bert/multihead_attention.cc | 215 +-
.../contrib_ops/cpu/bert/rotary_embedding.cc | 129 +-
.../contrib_ops/cpu/bert/rotary_embedding.h | 6 +
.../cpu/bert/rotary_embedding_helper.h | 25 +-
.../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 +
.../core/graph/contrib_ops/bert_defs.cc | 5 +-
..._flash_attn.py => test_flash_attn_cuda.py} | 0
.../test/python/transformers/test_gqa_cpu.py | 1884 +++++++++++++++++
18 files changed, 3159 insertions(+), 279 deletions(-)
create mode 100644 onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
create mode 100644 onnxruntime/contrib_ops/cpu/bert/attention_utils.h
create mode 100644 onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
create mode 100644 onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
create mode 100644 onnxruntime/contrib_ops/cpu/bert/group_query_attention.h
create mode 100644 onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
rename onnxruntime/test/python/transformers/{test_flash_attn.py => test_flash_attn_cuda.py} (100%)
create mode 100644 onnxruntime/test/python/transformers/test_gqa_cpu.py
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 3d984a54c0..8d3d807514 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2455,6 +2455,9 @@ This version of the operator has been available since version 1 of the 'com.micr
Group Query Self/Cross Attention.
Supports different number of heads for q and kv. Only supports causal or local attention.
+ Supports rotary position embedding.
+ Supports k-v cache.
+ CPU EP supports fp32... CUDA EP supports fp16.
#### Version
@@ -2514,7 +2517,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float16), tensor(bfloat16)
+- T : tensor(float16), tensor(bfloat16), tensor(float)
- Constrain input and output to float tensors.
- M : tensor(int32)
- Constrain mask to int tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index f54f264255..8fa67ee172 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -481,6 +481,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index a6782daa58..af902a713e 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -68,6 +68,7 @@ class AttentionBase {
const Tensor* past_seq_len = nullptr) const;
int num_heads_; // number of attention heads
+ int kv_num_heads_; // different for k and v for group query attention
bool is_unidirectional_; // whether every token can only attend to previous tokens.
std::vector qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute.
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
index f1a0ce994e..5d37213303 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -153,6 +153,49 @@ void PrepareMask(const int32_t* mask_index,
}
}
+// Writes lowest() into the masked slots of mask_data: causal (+ local window) when S > 1, right padding (+ local window) when S == 1.
+template
+void PrepareMaskGQA(T* mask_data,
+ int batch_size,
+ int sequence_length,
+ int buffer_sequence_length,
+ int local_window_size,
+ const int32_t* seqlens_k) {
+ // Existing note says mask_data arrives zero-filled; it is walked as B blocks of S x T with T = buffer_sequence_length.
+ T* p_mask = mask_data;
+ // TODO: parallelize this
+ for (int b_i = 0; b_i < batch_size; b_i++) {
+ if (sequence_length > 1) {
+ // Prompt case: row s_i must not see any column after s_i (seqlens_k is not read on this path).
+ for (int s_i = 0; s_i < sequence_length; s_i++) {
+ for (int m_i = s_i + 1; m_i < buffer_sequence_length; m_i++) {
+ p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits::lowest();
+ }
+ // Local attention (only when local_window_size > 0): also hide columns more than local_window_size behind s_i.
+ if (local_window_size > 0) {
+ for (int m_i = 0; m_i < s_i - local_window_size; m_i++) {
+ p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits::lowest();
+ }
+ }
+ }
+ } else if (sequence_length == 1) {
+ // Token generation case: hide every slot at or beyond seqlens_k[b_i] + 1 (right padding of the buffer).
+ int total_seqlen = seqlens_k[b_i] + 1;
+ for (int m_i = total_seqlen; m_i < buffer_sequence_length; m_i++) {
+ p_mask[m_i] = std::numeric_limits::lowest();
+ }
+ // Local attention (only when local_window_size > 0): hide slots before total_seqlen - local_window_size - 1.
+ if (local_window_size > 0) {
+ for (int m_i = 0; m_i < total_seqlen - local_window_size - 1; m_i++) {
+ p_mask[m_i] = std::numeric_limits::lowest();
+ }
+ }
+ }
+ ptrdiff_t mask_to_advance = SafeInt(sequence_length) * buffer_sequence_length;  // one S x T block per batch entry
+ p_mask += mask_to_advance;
+ }
+}
+
// Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH
// Returns a pointer to the start of present state chunk.
template
@@ -175,5 +218,32 @@ T* ConcatStateChunk(const T* past,
return start;
}
+// GQA variant of ConcatStateChunk: assembles chunk i of `present` as [past chunk i (token generation only) | new chunk] and returns its start.
+template
+T* ConcatStateChunkGQA(const T* past,
+ const T* chunk,
+ T* present,
+ size_t present_buff_chunk_length,  // elements between consecutive chunks of `present`
+ size_t past_buff_chunk_length,  // elements between consecutive chunks of `past`
+ size_t past_chunk_length,  // valid past elements to keep in front of the new data
+ size_t new_chunk_length,  // elements copied from `chunk`
+ bool is_prompt,
+ bool past_present_share_buffer,
+ std::ptrdiff_t i) {
+ T* start = present + i * present_buff_chunk_length;
+
+ T* p = start;
+ if (!is_prompt) {  // for a prompt, `past` is never read and the new chunk goes at the chunk start
+ if (!past_present_share_buffer) {  // the past copy is skipped when the caller reports a shared past/present buffer
+ const T* src_past = past + i * past_buff_chunk_length;
+ memcpy(p, src_past, past_chunk_length * sizeof(T));
+ }
+ p += past_chunk_length;  // new data is appended right after the valid past elements
+ }
+
+ memcpy(p, chunk, new_chunk_length * sizeof(T));
+ return start;
+}
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
new file mode 100644
index 0000000000..7b84971585
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -0,0 +1,246 @@
+#include "attention_utils.h"
+#include "core/common/common.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/framework/transpose_helper.h"
+#include "core/providers/cpu/tensor/reshape_helper.h"
+#include "core/providers/cpu/math/element_wise_ops.h"
+
+using onnxruntime::concurrency::ThreadPool;
+
+namespace onnxruntime {
+namespace contrib {
+
+// Reshape Q/K/V from BxSxD to BxSxNxH. Not `inline`: attention_utils.h declares it as an ordinary external function.
+Status Reshape_BSD_to_BSNH(Tensor* qkv,
+ int batch_size,
+ int sequence_length,
+ int num_heads,
+ int head_size) {
+ qkv->Reshape(TensorShape({batch_size, sequence_length, num_heads, head_size}));  // replaces the shape of *qkv in place
+ return Status::OK();  // always OK; Status return kept so callers can use ORT_RETURN_IF_ERROR
+}
+
+// Transpose Q/K/V from BxSxNxH to BxNxSxH. Not `inline`: attention_utils.h declares it as an ordinary external function.
+Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
+ OrtValue& qkv_transposed,
+ concurrency::ThreadPool* tp) {
+ std::vector permutations({0, 2, 1, 3});  // axis order of the result relative to the input
+ gsl::span permutations_span{permutations};
+ size_t from = 2, to = 1;  // the permutation above is axis 2 moved to position 1
+ SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp);  // writes into the tensor held by qkv_transposed
+ return Status::OK();
+}
+
+// Adds one slice of the combined bias to Q, K or V (BxSxD) and transposes the sum from BxSxNxH into BxNxSxH.
+template
+Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
+ const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
+ OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
+ int bias_offset, // element offset into qkv_bias where this tensor's bias starts
+ int batch_size, // batch size
+ int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
+ int num_heads, // num heads
+ int head_size, // head_size for Q/K, v_head_size for V
+ int hidden_size, // hidden_size for Q/K, v_hidden_size for V
+ OpKernelContext* context) {
+ // Note: the comments below will refer to Q's dimensions for simplicity
+ auto element_type = DataTypeImpl::GetType();
+ constexpr size_t element_size = sizeof(T);
+ ProcessBroadcastSpanFuncs add_funcs{
+ [](BroadcastHelper& per_iter_bh) {
+ per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
+ },
+ [](BroadcastHelper& per_iter_bh) {
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
+ },
+ [](BroadcastHelper& per_iter_bh) {
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
+ }}; // Element-wise add in three forms: scalar + span, span + scalar, span + span
+
+ // Allocate a temporary BxSxD tensor for the output of Q(BS, D) + bias(D)
+ AllocatorPtr allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+ std::vector old_dims({batch_size, sequence_length, hidden_size});
+ gsl::span old_dims_span{old_dims};
+ TensorShape qkv_with_bias_shape(old_dims_span);
+ OrtValue qkv_with_bias;
+ Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias);
+
+ // Copy Q's hidden_size bias elements out of the combined bias into their own 1-D tensor
+ std::vector bias_dims({hidden_size});
+ gsl::span bias_dims_span{bias_dims};
+ TensorShape bias_shape(bias_dims_span);
+ OrtValue bias;
+ Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
+ memcpy(bias.GetMutable()->MutableData(), qkv_bias + bias_offset, hidden_size * element_size);
+
+ // Compute Q(BS, D) + bias(D) as broadcasted element-wise add, split into spans across the thread pool
+ {
+ InputBroadcaster input_broadcaster(*bias.GetMutable(), *qkv);  // input 0 is the bias, input 1 is the Q/K/V tensor
+ const InputBroadcaster& const_input_broadcaster = input_broadcaster;
+ Tensor& output_tensor = *qkv_with_bias.GetMutable();
+
+ size_t span_size = input_broadcaster.GetSpanSize();
+ size_t output_size = static_cast(output_tensor.Shape().Size());
+ void* user_data = nullptr;
+
+ const int loop_len = static_cast(output_size / span_size);  // number of spans to process
+ double unit_cost = 1.0f;
+ const auto cost = TensorOpCost{static_cast(input_broadcaster.Input0ElementSize()) * span_size,
+ static_cast(output_tensor.DataType()->Size()) * span_size,
+ unit_cost * span_size};
+ auto tp = context->GetOperatorThreadPool();
+ ThreadPool::TryParallelFor(tp, loop_len, cost,
+ [span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
+ std::ptrdiff_t last_span) {
+ InputBroadcaster segment_input_broadcaster(const_input_broadcaster);  // private copy, advanced to this segment
+ segment_input_broadcaster.AdvanceBy(first_span * span_size);
+
+ OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
+ first_span * span_size, last_span * span_size);
+
+ BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
+ BroadcastLooper(segment_helper, add_funcs);
+ });
+ }
+
+ // Reshape Q from BxSxD to BxSxNxH
+ ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size));
+
+ // Transpose Q from BxSxNxH to BxNxSxH, writing into the caller-provided output value
+ auto tp = context->GetOperatorThreadPool();
+ ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed, tp));
+
+ return Status::OK();
+}
+
+// Adds one slice of the combined bias to Q, K or V and relabels the result as BxNxSxH without transposing.
+// This is used in decoder_with_past when the sequence length is 1
+template
+Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
+ const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
+ OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
+ int bias_offset, // element offset into qkv_bias where this tensor's bias starts
+ int batch_size, // batch size
+ int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
+ int num_heads, // num heads
+ int head_size, // head_size for Q/K, v_head_size for V
+ int hidden_size, // hidden_size for Q/K, v_hidden_size for V
+ OpKernelContext* context) {
+ // Note: the comments below will refer to Q's dimensions for simplicity
+ auto element_type = DataTypeImpl::GetType();
+ constexpr size_t element_size = sizeof(T);
+ ProcessBroadcastSpanFuncs add_funcs{
+ [](BroadcastHelper& per_iter_bh) {
+ per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
+ },
+ [](BroadcastHelper& per_iter_bh) {
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
+ },
+ [](BroadcastHelper& per_iter_bh) {
+ per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
+ }}; // Element-wise add in three forms: scalar + span, span + scalar, span + span
+
+ // Copy Q's hidden_size bias elements out of the combined bias into their own 1-D tensor
+ AllocatorPtr allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+ std::vector bias_dims({hidden_size});
+ gsl::span bias_dims_span{bias_dims};
+ TensorShape bias_shape(bias_dims_span);
+ OrtValue bias;
+ Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
+ auto num_bias_elements = SafeInt(hidden_size) * element_size;  // despite the name, this is the byte count handed to memcpy
+ memcpy(bias.GetMutable()->MutableData(), qkv_bias + bias_offset, num_bias_elements);
+
+ // Compute Q(BS, D) + bias(D) as broadcasted element-wise add, written straight into the caller's output tensor
+ {
+ InputBroadcaster input_broadcaster(*bias.GetMutable(), *qkv);  // input 0 is the bias, input 1 is the Q/K/V tensor
+ const InputBroadcaster& const_input_broadcaster = input_broadcaster;
+ Tensor& output_tensor = *qkv_with_bias.GetMutable();
+
+ size_t span_size = input_broadcaster.GetSpanSize();
+ size_t output_size = static_cast(output_tensor.Shape().Size());
+ void* user_data = nullptr;
+
+ const int loop_len = static_cast(output_size / span_size);  // number of spans to process
+ double unit_cost = 1.0f;
+ const auto cost = TensorOpCost{static_cast(input_broadcaster.Input0ElementSize()) * span_size,
+ static_cast(output_tensor.DataType()->Size()) * span_size,
+ unit_cost * span_size};
+ auto tp = context->GetOperatorThreadPool();
+ ThreadPool::TryParallelFor(tp, loop_len, cost,
+ [span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
+ std::ptrdiff_t last_span) {
+ InputBroadcaster segment_input_broadcaster(const_input_broadcaster);  // private copy, advanced to this segment
+ segment_input_broadcaster.AdvanceBy(first_span * span_size);
+
+ OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
+ first_span * span_size, last_span * span_size);
+
+ BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
+ BroadcastLooper(segment_helper, add_funcs);
+ });
+ }
+
+ // Relabel the BxSxD sum as BxNxSxH via Reshape only (the caller in this file takes this path when sequence_length == 1)
+ qkv_with_bias.GetMutable()->Reshape(TensorShape({batch_size, num_heads, sequence_length, head_size}));
+
+ return Status::OK();
+}
+
+template
+Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,  // fills `out` with `in` (+ optional bias) laid out as BxNxSxH
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) {
+ auto element_type = DataTypeImpl::GetType();
+ std::vector new_dims({batch_size, num_heads, sequence_length, head_size});
+ gsl::span new_dims_span{new_dims};
+ TensorShape v_BNLH(new_dims_span);
+ Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);  // `out` is allocated here from `allocator`
+ if (bias == nullptr) {  // no bias: transpose only
+ std::unique_ptr reshaped;
+ if (in->Shape().GetDims().size() == 3) {  // rank-3 input is first viewed as BxSxNxH
+ reshaped = std::make_unique(in->DataType(), in->Shape(), const_cast(in->DataRaw()), in->Location());  // second Tensor over the input's existing buffer
+ ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
+ }
+ ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));  // an input that was not rank 3 is passed through as is
+ } else {
+ const auto* qkv_bias = bias->Data();
+ if (sequence_length == 1) {  // single token: the reshape-only helper is used, no transpose call
+ ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
+ } else {
+ ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
+ }
+ }
+ return Status::OK();
+}
+
+template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+
+template
+Status MaybeTransposeToBNSH(AllocatorPtr allocator,  // fills `out` with `in` laid out as BxNxSxH; no bias handling
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, OrtValue& out) {
+ auto element_type = DataTypeImpl::GetType();
+ std::vector new_dims({batch_size, num_heads, sequence_length, head_size});
+ gsl::span new_dims_span{new_dims};
+ TensorShape v_BNLH(new_dims_span);
+ Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);  // `out` is allocated here from `allocator`
+ std::unique_ptr reshaped;
+ if (in->Shape().GetDims().size() == 3) {  // rank-3 input is first viewed as BxSxNxH
+ reshaped = std::make_unique(in->DataType(), in->Shape(), const_cast(in->DataRaw()), in->Location());  // second Tensor over the input's existing buffer
+ ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
+ }
+ ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));  // no thread pool is passed, so the declared default applies
+
+ return Status::OK();
+}
+
+template Status MaybeTransposeToBNSH(AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, OrtValue& out);
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.h b/onnxruntime/contrib_ops/cpu/bert/attention_utils.h
new file mode 100644
index 0000000000..d7fb0c496b
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.h
@@ -0,0 +1,61 @@
+#pragma once
+#include "core/common/common.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/framework/transpose_helper.h"
+#include "core/providers/cpu/tensor/reshape_helper.h"
+#include "core/providers/cpu/math/element_wise_ops.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+// Reshape Q/K/V from BxSxD to BxSxNxH
+Status Reshape_BSD_to_BSNH(Tensor* qkv,
+ int batch_size,
+ int sequence_length,
+ int num_heads,
+ int head_size);
+
+// Transpose Q/K/V from BxSxNxH to BxNxSxH
+Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
+ OrtValue& qkv_transposed,
+ concurrency::ThreadPool* tp = nullptr);
+
+// Add bias + transpose for each of Q/K/V
+template
+Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
+ const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
+ OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
+ int bias_offset, // bias offset to enter qkv_bias
+ int batch_size, // batch size
+ int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
+ int num_heads, // num heads
+ int head_size, // head_size for Q/K, v_head_size for V
+ int hidden_size, // hidden_size for Q/K, v_hidden_size for V
+ OpKernelContext* context); // OpKernelContext
+
+// Add bias + reshape for each of Q/K/V
+// This is used in decoder_with_past when the sequence length is 1
+template
+Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
+ const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
+ OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
+ int bias_offset, // bias offset to enter qkv_bias
+ int batch_size, // batch size
+ int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
+ int num_heads, // num heads
+ int head_size, // head_size for Q/K, v_head_size for V
+ int hidden_size, // hidden_size for Q/K, v_hidden_size for V
+ OpKernelContext* context); // OpKernelContext
+
+template
+Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+
+template
+Status MaybeTransposeToBNSH(AllocatorPtr allocator,
+ int batch_size, int num_heads, int sequence_length, int head_size,
+ const Tensor* in, OrtValue& out);
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
new file mode 100644
index 0000000000..b0ebc50215
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -0,0 +1,276 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "attention_base.h"
+#include "attention_helper.h"
+
+#include "core/common/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+#include "core/common/safeint.h"
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class GQAAttentionBase : public AttentionBase {
+ protected:
+ GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size)
+ : AttentionBase(info, require_same_hidden_size) {}
+
+ int local_window_size_;
+ bool do_rotary_;
+ bool rotary_interleaved_;
+
+ template
+ Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH; when parameters.is_packed_qkv, K and V are located by offset from this pointer
+ const T* K, // K data with shape BxN_kvxSxH
+ const T* V, // V data with shape BxN_kvxSxH
+ const Tensor* past_key, // past K input tensor (if not using past state)
+ const Tensor* past_value, // past V input tensor (if not using past state)
+ Tensor* output, // output tensor
+ Tensor* present_key, // present K output tensor (if separating present KV)
+ Tensor* present_value, // present V output tensor (if separating present KV)
+ const Tensor* seqlens_k, // int32 per-batch lengths; read as the past length when S == 1
+ GroupQueryAttentionParameters& parameters, // attention parameters
+ AllocatorPtr allocator, // allocator for temporary tensors
+ OpKernelContext* context) const {
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int head_size = parameters.head_size;
+ const int hidden_size = parameters.hidden_size;
+ const bool packed_qkv = parameters.is_packed_qkv;
+
+ auto* tp = context->GetOperatorThreadPool();
+
+ int seqlen_past_kv_cache = 0;  // stays 0 unless both past tensors are supplied
+ if (past_key != nullptr && past_value != nullptr) {
+ seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]);  // dimension 2 is read as the past buffer length
+ }
+ int seqlen_present_kv_cache = static_cast(present_key->Shape().GetDims()[2]);  // NOTE(review): present_key is dereferenced here without a null check
+
+ // Scratch for the attention scores/probabilities: B x N x S x T elements, T = present buffer length.
+ size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
+ auto attention_probs = allocator->Alloc(bytes);
+ BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
+
+ void* mask_data = nullptr;  // B x S x T mask, zero-filled below before ComputeAttentionProbs fills it in
+ size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * seqlen_present_kv_cache * sizeof(T);
+ mask_data = allocator->Alloc(mask_data_bytes);
+ memset(mask_data, 0, mask_data_bytes);
+ BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
+
+ const T* past_key_data = past_key != nullptr ? past_key->Data() : nullptr;
+ T* present_key_data = present_key != nullptr ? present_key->MutableData() : nullptr;
+ const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr;
+ T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr;
+
+ bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;  // true only when both pointer pairs are identical
+
+ const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;  // packed: K starts after the num_heads_ Q chunks
+ ComputeAttentionProbs(static_cast(attention_probs), Q, k,
+ seqlens_k->Data(), static_cast(mask_data),
+ batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache,
+ head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp);
+
+ // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
+ auto out_tmp_data =
+ allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
+ BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
+
+ const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;  // packed: V starts after the Q and K chunks
+ ComputeVxAttentionScore(output->MutableData(), static_cast(out_tmp_data), static_cast(attention_probs),
+ v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache,
+ seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
+ past_present_share_buffer, packed_qkv, tp);
+
+ return Status::OK();
+ }
+
+ private:
+ // Helper function to compute the attention probs. It does 2 things:
+ // attention_probs(B, N, S, T) = alpha x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) +
+ // 1 x mask_data(B, S, T) broadcast over N, where alpha = scale_, or 1/sqrt(H) when scale_ == 0
+ // attention_probs(B, N, S, T) = Softmax(attention_probs)
+ template
+ void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
+ const T* Q, // Q data. Its size is BxNxSxH
+ const T* K, // K data, one SxH chunk per kv head (BxN_kvxSxH when not packed)
+ const int32_t* seqlens_k, // per-batch lengths; read as the past length when S == 1
+ T* mask_data, // BxSxT mask buffer; filled in here by PrepareMaskGQA
+ int batch_size, // batch size of self-attention
+ int sequence_length, // sequence length of self-attention (S)
+ int past_buffer_sequence_length, // sequence length of past state
+ int present_buffer_sequence_length, // sequence length of present state
+ int head_size, // head size of self-attention
+ const T* past_key, // past key only (if not using past state)
+ T* present_key, // present key only (if not using present state)
+ bool past_present_share_buffer, // whether past and present key/value are the same buffers
+ bool packed_qkv, // whether Q, K, V are packed
+ ThreadPool* tp) const { // thread pool
+ const bool is_prompt = sequence_length != 1;
+ const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;  // elements per batch entry in the packed buffer
+ const int kv_num_heads_factor = num_heads_ / kv_num_heads_;  // number of query heads that share one kv head
+ const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H
+ const size_t kv_input_chunk_length = static_cast(sequence_length) * head_size; // S x H (new K rows per kv head)
+ const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H
+ const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H
+
+ PrepareMaskGQA(mask_data, batch_size, sequence_length, present_buffer_sequence_length, local_window_size_, seqlens_k);
+
+ const int loop_len = batch_size * num_heads_;  // one work item per (batch, query head)
+ const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_;
+
+ // TODO: cost might differ for gqa because of right padding and total_sequence_length being sequence dependent
+ TensorOpCost unit_cost;
+ const size_t probs_matrix_bytes = SafeInt(sequence_length) * present_buffer_sequence_length * sizeof(T);  // bytes in one S x T matrix
+ unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * present_buffer_sequence_length);
+ unit_cost.bytes_loaded = static_cast((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T));
+ unit_cost.bytes_stored = static_cast(probs_matrix_bytes);
+
+ unit_cost.bytes_loaded += static_cast(probs_matrix_bytes);  // accounts for the mask copy done per work item
+ unit_cost.bytes_stored += static_cast(probs_matrix_bytes);
+
+ if (present_key) {
+ double bytes_to_copy_key = static_cast(sizeof(T) * present_buff_chunk_length);
+ unit_cost.bytes_loaded += bytes_to_copy_key;
+ unit_cost.bytes_stored += bytes_to_copy_key;
+ }
+
+ ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+ const int batch_index = static_cast(i) / num_heads_;
+ const int head_index = static_cast(i) % num_heads_;
+ const int past_seqlen = sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length;
+ const size_t past_chunk_length = static_cast(past_seqlen) * head_size;
+
+ const int output_offset = static_cast(i) * sequence_length * present_buffer_sequence_length;
+ const int mask_offset = batch_index * sequence_length * present_buffer_sequence_length;
+ T* output = attention_probs + output_offset;
+
+ // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
+ // TODO: mask after present_sequence_length
+ memcpy(output,
+ mask_data + mask_offset,
+ probs_matrix_bytes);
+
+ const T* k;  // new K rows for the kv head serving this query head
+ if (packed_qkv) {
+ k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
+ } else {
+ k = K + kv_input_chunk_length * (i / kv_num_heads_factor);
+ }
+ if (nullptr != present_key) {  // append the new rows to the cache, then use the cache chunk as K
+ k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
+ past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+ i / kv_num_heads_factor);
+ }
+
+ // TODO: CblasTrans stuff what do?
+ // Compute Q*K' + AttentionMask
+ // original transposed each iteration
+ // A: Q (B x N x) S x H (B x N x) S x H S x H
+ // B: K' (B x N x) T x H (B x N x) H x T H x T
+ // C: attention_probs (B x N x) S x T (B x N x) S x T S x T
+ const T* q;
+ if (packed_qkv) {
+ q = Q + packed_batch_stride * batch_index + q_input_chunk_length * head_index;
+ } else {
+ q = Q + q_input_chunk_length * i;
+ }
+ math::Gemm(CblasNoTrans, CblasTrans, sequence_length, present_buffer_sequence_length, head_size, alpha,
+ q, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr);  // beta of 1 keeps the mask already copied into output
+ }
+ });
+
+ // attention_probs(B, N, S, T) = Softmax(attention_probs), taken over each row of length T
+ const int N = batch_size * num_heads_ * sequence_length;
+ const int D = present_buffer_sequence_length;
+ ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
+ }
+
+ template
+ void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
+ T* tmp_buffer, // temporary buffer with size BxNxSxH
+ const T* attention_probs, // Attention probs with size BxNxSxT
+ const T* V, // V value with size BxN_kvxSxH
+ const int32_t* seqlens_k, // per-batch lengths; read as the past length when S == 1
+ int batch_size, // batch size
+ int sequence_length, // sequence length
+ int past_buffer_sequence_length, // sequence length in past state
+ int present_buffer_sequence_length, // sequence length in present state
+ int head_size, // head size of Q, K, V
+ int hidden_size, // hidden size of Output
+ const T* past_value, // past value only (if not using past state)
+ T* present_value, // present value only (if not using present state)
+ bool past_present_share_buffer, // whether past and present key/value are the same buffers
+ bool packed_qkv, // whether Q, K, V are packed
+ ThreadPool* tp) const {
+ const bool is_prompt = sequence_length != 1;
+ const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;  // elements per batch entry in the packed buffer
+ const int kv_num_heads_factor = num_heads_ / kv_num_heads_;  // number of query heads that share one kv head
+ // TODO: what is with these being ptrdiff_t?
+ const ptrdiff_t q_input_chunk_length = SafeInt(sequence_length) * head_size; // S x H
+ const ptrdiff_t kv_input_chunk_length = SafeInt(sequence_length) * head_size; // S x H (new V rows per kv head)
+ const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H
+ const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H
+
+ // The cost of Gemm
+ TensorOpCost unit_cost;
+ unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * present_buffer_sequence_length);
+ unit_cost.bytes_loaded = static_cast((sequence_length + head_size) * present_buffer_sequence_length * sizeof(T));
+ unit_cost.bytes_stored = static_cast(sequence_length * head_size * sizeof(T));
+
+ if (present_value) {
+ double bytes_to_copy_value = static_cast(present_buff_chunk_length * sizeof(T));
+ unit_cost.bytes_loaded += bytes_to_copy_value;
+ unit_cost.bytes_stored += bytes_to_copy_value;
+ }
+
+ const size_t bytes_to_copy_trans = SafeInt(head_size) * sizeof(T);  // bytes in one output row of H elements
+ double bytes_to_copy_trans_all = static_cast(sequence_length * bytes_to_copy_trans);
+ unit_cost.bytes_loaded += bytes_to_copy_trans_all;
+ unit_cost.bytes_stored += bytes_to_copy_trans_all;
+
+ ThreadPool::TryParallelFor(tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+ const int batch_index = static_cast(i / num_heads_);
+ const int head_index = static_cast(i % num_heads_);
+ const int past_seqlen = sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length;
+ const size_t past_chunk_length = static_cast(past_seqlen) * head_size;
+
+ const T* v;  // new V rows for the kv head serving this query head
+ if (packed_qkv) {
+ v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
+ } else {
+ v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
+ }
+ if (nullptr != present_value) {  // append the new rows to the cache, then use the cache chunk as V
+ v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
+ past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+ i / kv_num_heads_factor);
+ }
+
+ T* current_tmp_data = reinterpret_cast(tmp_buffer) + q_input_chunk_length * i;  // S x H scratch slot for this (batch, head)
+ ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i;
+ math::MatMul(sequence_length, head_size, present_buffer_sequence_length,
+ attention_probs + attention_probs_offset,
+ v, current_tmp_data, nullptr);
+
+ // Transpose: out_tmp(B, N, S, H_v) -> out(B, S, N, H_v), copied one row of H elements at a time
+ T* src = current_tmp_data;
+ ptrdiff_t dest_offset = (SafeInt(batch_index) * sequence_length * num_heads_ + head_index) * head_size;
+ T* dest = output + dest_offset;
+ for (int j = 0; j < sequence_length; j++) {
+ memcpy(dest, src, bytes_to_copy_trans);
+ src += head_size;
+ dest += hidden_size;  // next sequence position of the same head in the output
+ }
+ }
+ });
+ }
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
new file mode 100644
index 0000000000..8e6b202ca4
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -0,0 +1,192 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "group_query_attention.h"
+#include "group_query_attention_helper.h"
+#include "attention_utils.h"
+#include "rotary_embedding.h"
+#include "rotary_embedding_helper.h"
+
+#include "core/framework/tensorprotoutils.h"
+#include "core/graph/onnx_protobuf.h"
+#include "core/common/safeint.h"
+#include "core/platform/threadpool.h"
+
+#include
+#include
+
+using onnxruntime::concurrency::ThreadPool;
+
+namespace onnxruntime {
+namespace contrib {
+
+// These ops are internal-only, so register outside of onnx.
+// Registers the CPU GroupQueryAttention kernel under kMSDomain, version 1, for float data.
+// Type constraint "T" is bound to float tensors and "M" to int32 tensors.
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    GroupQueryAttention,
+    kMSDomain,
+    1,
+    float,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()),
+    GroupQueryAttention<float>);
+
+// Reads the operator attributes once, at kernel-creation time.
+//   num_heads, kv_num_heads : required; construction fails (ORT_ENFORCE) if missing or <= 0.
+//   mask_filter_value       : optional, defaults to -10000.0f.
+//   local_window_size       : optional, defaults to -1.
+//   do_rotary               : optional, defaults to 0; enabled only when the attribute equals 1.
+//   rotary_interleaved      : optional, defaults to 0; enabled only when the attribute equals 1.
+template <typename T>
+GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
+    : OpKernel(info), GQAAttentionBase(info, false) {
+  int64_t num_heads = 0;
+  int64_t kv_num_heads = 0;
+  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+  ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
+  num_heads_ = static_cast<int>(num_heads);
+  kv_num_heads_ = static_cast<int>(kv_num_heads);
+
+  mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
+  local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
+  do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+  rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
+}
+
+// Runs GroupQueryAttention for the tensors bound to `context`.
+//
+// Inputs (by index): 0 query, 1 key, 2 value, 3 past_key, 4 past_value, 5 seqlens_k,
+//                    6 total_seqlen, 7 cos_cache, 8 sin_cache.
+// Outputs (by index): 0 output        (batch_size, sequence_length, hidden_size)
+//                     1 present_key   (batch_size, kv_num_heads, present_kv_seqlen, head_size)
+//                     2 present_value (batch_size, kv_num_heads, present_kv_seqlen, head_size)
+//
+// Steps: validate inputs, bring Q/K/V into BNSH layout, optionally apply rotary embedding to
+// Q and K, then delegate to ApplyAttention.
+// Returns the first non-OK status produced by validation, allocation, or any helper.
+template <typename T>
+Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
+  const Tensor* query = context->Input<Tensor>(0);
+  const Tensor* key = context->Input<Tensor>(1);
+  const Tensor* value = context->Input<Tensor>(2);
+  const Tensor* past_key = context->Input<Tensor>(3);
+  const Tensor* past_value = context->Input<Tensor>(4);
+  const Tensor* seqlens_k = context->Input<Tensor>(5);
+  const Tensor* total_seqlen = context->Input<Tensor>(6);
+  const Tensor* cos_cache = context->Input<Tensor>(7);
+  const Tensor* sin_cache = context->Input<Tensor>(8);
+
+  GroupQueryAttentionParameters parameters = {};
+  constexpr float scale = 1.0f;
+  ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
+                                                                key,
+                                                                value,
+                                                                past_key,
+                                                                past_value,
+                                                                cos_cache,
+                                                                sin_cache,
+                                                                &parameters,
+                                                                num_heads_,
+                                                                kv_num_heads_,
+                                                                seqlens_k,
+                                                                total_seqlen,
+                                                                scale));
+
+  // CheckInputs accepts an absent cos/sin pair, but the rotary path below dereferences both,
+  // so reject the combination here instead of crashing.
+  if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1");
+  }
+
+  const int batch_size = parameters.batch_size;
+  const int sequence_length = parameters.sequence_length;
+  const int present_kv_seqlen = parameters.seqlen_present_kv_cache;
+  const int head_size = parameters.head_size;
+  const int q_hidden_size = parameters.hidden_size;
+  const bool packed_qkv = parameters.is_packed_qkv;
+
+  std::vector<int64_t> output_shape(3);
+  output_shape[0] = static_cast<int64_t>(batch_size);
+  output_shape[1] = static_cast<int64_t>(sequence_length);
+  output_shape[2] = static_cast<int64_t>(q_hidden_size);
+  Tensor* output = context->Output(0, output_shape);
+
+  std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
+                                        static_cast<int64_t>(kv_num_heads_),
+                                        static_cast<int64_t>(present_kv_seqlen),
+                                        static_cast<int64_t>(head_size)});
+  std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size),
+                                        static_cast<int64_t>(kv_num_heads_),
+                                        static_cast<int64_t>(present_kv_seqlen),
+                                        static_cast<int64_t>(head_size)});
+  Tensor* present_k = context->Output(1, present_k_shape);
+  Tensor* present_v = context->Output(2, present_v_shape);
+
+  // One temp-space allocator serves the transposed copies, the rotary buffers and ApplyAttention.
+  AllocatorPtr allocator;
+  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+
+  auto element_type = DataTypeImpl::GetType<T>();
+  OrtValue Q;
+  OrtValue K;
+  OrtValue V;
+  if (packed_qkv) {
+    // Packed input: Q, K and V heads all live in `query`, so transpose them together.
+    ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
+        allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
+  } else if (sequence_length > 1) {
+    ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
+        allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
+    ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
+        allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
+    ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
+        allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
+  } else {
+    // Single-token case: no transpose is performed and the inputs are wrapped as-is.
+    // NOTE(review): this moves out of the kernel's const input tensors via const_cast;
+    // presumably nothing reads query/key/value through the context afterwards - confirm.
+    Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*query)), Q);
+    Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*key)), K);
+    Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*value)), V);
+  }
+
+  if (do_rotary_) {
+    rotary_embedding_helper::RotaryParameters rotary_params = {};
+    rotary_params.batch_size = batch_size;
+    rotary_params.sequence_length = sequence_length;
+    rotary_params.hidden_size = q_hidden_size;
+    rotary_params.head_size = head_size;
+    rotary_params.rotary_embedding_dim = parameters.rotary_dim;
+    rotary_params.num_heads = num_heads_;
+    rotary_params.max_sequence_length = sequence_length;  // unused
+    // Strides describe the BNSH buffers prepared above.
+    rotary_params.seq_stride = head_size;
+    rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
+    rotary_params.batch_stride =
+        (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride;
+    rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
+    rotary_params.transposed = true;
+    auto* tp = context->GetOperatorThreadPool();
+
+    // Single-token case: one position per batch entry, taken from seqlens_k.
+    // Otherwise: a single starting position of 0.
+    std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
+    if (sequence_length == 1) {
+      const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
+      for (int b = 0; b < batch_size; b++) {
+        pos_ids[b] = static_cast<int64_t>(seqlens_k_data[b]);
+      }
+    } else {
+      pos_ids[0] = static_cast<int64_t>(0);
+    }
+
+    const T* q_input;
+    const T* k_input;
+    T* q_rotary;
+    T* k_rotary;
+    if (packed_qkv) {
+      // One buffer holds rotated Q, rotated K and (copied below) V; K starts after the Q heads.
+      OrtValue RotaryQKV;
+      Tensor::InitOrtValue(element_type,
+                           TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}),
+                           allocator, RotaryQKV);
+      q_input = Q.Get<Tensor>().Data<T>();
+      k_input = q_input + num_heads_ * sequence_length * head_size;
+      q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
+      k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
+      Q = RotaryQKV;
+    } else {
+      OrtValue RotaryQ;
+      Tensor::InitOrtValue(element_type,
+                           TensorShape({batch_size, num_heads_, sequence_length, head_size}),
+                           allocator, RotaryQ);
+      OrtValue RotaryK;
+      Tensor::InitOrtValue(element_type,
+                           TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}),
+                           allocator, RotaryK);
+      q_input = Q.Get<Tensor>().Data<T>();
+      k_input = K.Get<Tensor>().Data<T>();
+      q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>();
+      k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>();
+      Q = RotaryQ;
+      K = RotaryK;
+    }
+
+    // Rotate Q.
+    ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
+                                              pos_ids.data(), cos_cache->Data<T>(),
+                                              sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
+
+    // Rotate K: same parameters except for the head count / hidden size, and (when K is a
+    // separate buffer) the batch stride.
+    rotary_params.num_heads = kv_num_heads_;
+    rotary_params.hidden_size = parameters.kv_hidden_size;
+    if (!packed_qkv) {
+      rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride;
+    }
+    ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
+                                              pos_ids.data(), cos_cache->Data<T>(),
+                                              sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
+
+    if (packed_qkv) {
+      // V is not rotated, but it must follow K inside the new packed buffer.
+      const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
+      T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
+      ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV<T>(tp, parameters, v_input, v_rotary));
+    }
+  }
+
+  // Compute the attention score and apply the score to V.
+  // With packed QKV, K and V are passed as nullptr and Q carries the packed buffer.
+  return ApplyAttention(Q.Get<Tensor>().Data<T>(),
+                        packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(),
+                        packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
+                        past_key, past_value, output, present_k, present_v,
+                        seqlens_k, parameters, allocator, context);
+}
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.h
new file mode 100644
index 0000000000..cbe99e0378
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "gqa_attention_base.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+// CPU kernel class for the GroupQueryAttention contrib operator.
+// Implements the OpKernel interface and inherits from GQAAttentionBase.
+template <typename T>
+class GroupQueryAttention final : public OpKernel, public GQAAttentionBase {
+ public:
+  // Constructs the kernel from its node attributes; explicit to prevent accidental
+  // implicit conversion from OpKernelInfo.
+  explicit GroupQueryAttention(const OpKernelInfo& info);
+
+  // Executes the operator on the inputs/outputs bound to `context`.
+  Status Compute(OpKernelContext* context) const override;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
new file mode 100644
index 0000000000..e2a615eba4
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -0,0 +1,299 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace group_query_attention_helper {
+
+// Validates the GroupQueryAttention inputs and, on success, records the derived sizes.
+//
+// Parameters:
+//   query, key, value     : key == nullptr selects packed QKV (value must then be nullptr too).
+//   past_key, past_value  : both present or both absent.
+//   cos_cache, sin_cache  : both present or both absent; when present they define rotary_dim.
+//   parameters            : optional out-parameter; when non-null it must point to a
+//                           GroupQueryAttentionParameters, which is filled in.
+//   num_heads, kv_num_heads : head counts; num_heads must be a multiple of kv_num_heads.
+//   seqlens_k             : must hold exactly batch_size elements.
+//   total_seqlen          : scalar or one-element tensor.
+//   scale                 : copied into the output parameters unchanged.
+// Returns Status::OK(), or INVALID_ARGUMENT describing the first violated constraint.
+//
+// Declared inline because this definition lives in a header.
+inline Status CheckInputs(const Tensor* query,
+                          const Tensor* key,
+                          const Tensor* value,
+                          const Tensor* past_key,
+                          const Tensor* past_value,
+                          const Tensor* cos_cache,
+                          const Tensor* sin_cache,
+                          void* parameters,
+                          int num_heads,
+                          int kv_num_heads,
+                          const Tensor* seqlens_k,
+                          const Tensor* total_seqlen,
+                          float scale) {
+  // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
+  //     past_key                   : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
+  //     past_value                 : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
+  // query/key/value, unpacked or packed:
+  //     query            (Q)       : (B, S, D) unpacked, or (B, S, (D_q + 2 D_kv)) packed
+  //     key              (K)       : (B, S, D_kv) or nullptr
+  //     value            (V)       : (B, S, D_kv) or nullptr
+
+  AttentionQkvFormat qkv_format = Q_K_V_BSNH;
+  AttentionQkvFormat past_kv_format = Q_K_V_BNSH;
+  const bool is_packed_qkv = key == nullptr;
+
+  const auto& query_dims = query->Shape().GetDims();
+  if (query_dims.size() != 3) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
+                           query_dims.size());
+  }
+
+  int batch_size = static_cast<int>(query_dims[0]);
+  int sequence_length = static_cast<int>(query_dims[1]);
+  int q_hidden_size = static_cast<int>(query_dims[2]);
+  int head_size = 0;
+
+  if (num_heads % kv_num_heads != 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
+                           num_heads % kv_num_heads);
+  }
+
+  int kv_hidden_size = 0;
+  // Check key and value when not packed
+  if (!is_packed_qkv) {
+    // Without this check the integer division below silently truncates and later code
+    // would index the buffers with a head_size that does not match the data.
+    if (q_hidden_size % num_heads != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'query' hidden size must be a multiple of num_heads, got ",
+                             q_hidden_size);
+    }
+    head_size = static_cast<int>(q_hidden_size) / num_heads;
+    if (head_size % 8 != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_size must be a multiple of 8. Got head_size % 8 == ",
+                             head_size % 8);
+    }
+    if (value == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+    }
+    const auto& key_dims = key->Shape().GetDims();
+    if (key_dims.size() != 3) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
+                             key_dims.size());
+    } else if (query_dims[0] != key_dims[0]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'query' and 'key' shall have same dim 0 (batch size)");
+    } else if (query_dims[1] != key_dims[1]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'query' and 'key' shall have same dim 1 (sequence length)");
+    }
+    kv_hidden_size = static_cast<int>(key_dims[2]);
+    const auto& value_dims = value->Shape().GetDims();
+    if (value_dims.size() != 3) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
+                             value_dims.size());
+    } else if (query_dims[0] != value_dims[0]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'query' and 'value' shall have same dim 0 (batch size)");
+    } else if (query_dims[1] != value_dims[1]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'query' and 'value' shall have same dim 1 (sequence length)");
+    } else if (value_dims[2] != kv_hidden_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
+    }
+  } else {
+    // Check packed qkv: the last dimension of query holds num_heads + 2 * kv_num_heads heads.
+    const int packed_num_heads = num_heads + 2 * kv_num_heads;
+    if (q_hidden_size % packed_num_heads != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Packed 'query' hidden size must be a multiple of (num_heads + 2 * kv_num_heads), got ",
+                             q_hidden_size);
+    }
+    head_size = static_cast<int>(q_hidden_size) / packed_num_heads;
+    if (head_size % 8 != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_size must be a multiple of 8. Got head_size % 8 == ",
+                             head_size % 8);
+    }
+    if (value != nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+    }
+    q_hidden_size = head_size * num_heads;
+    kv_hidden_size = head_size * kv_num_heads;
+  }
+
+  // Check past-present KV
+  int32_t past_sequence_length = 0;
+  if (past_key != nullptr && past_value != nullptr) {
+    const auto& past_key_dims = past_key->Shape().GetDims();
+    const auto& past_value_dims = past_value->Shape().GetDims();
+
+    if (past_key_dims.size() != 4) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_key' is expected to have 4 dimensions, got ",
+                             past_key_dims.size());
+    }
+    if (past_value_dims.size() != 4) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_value' is expected to have 4 dimensions, got ",
+                             past_value_dims.size());
+    }
+
+    if (past_key_dims[0] != batch_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_key' dimension 0 should be batch_size, got ",
+                             past_key_dims[0]);
+    }
+    if (past_value_dims[0] != batch_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_value' dimension 0 should be batch_size, got ",
+                             past_value_dims[0]);
+    }
+
+    if (past_key_dims[2] != past_value_dims[2]) {
+      // Report the dimension that was actually compared (dim 2).
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence "
+                             "length or past sequence length), got ",
+                             past_key_dims[2]);
+    }
+    if (past_key_dims[1] != kv_num_heads) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_key' shall have kv_num_heads");
+    }
+    if (past_value_dims[1] != kv_num_heads) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_value' shall have kv_num_heads");
+    }
+    // We assume all sequence in past kv are right-padded to max or past sequence length
+    past_sequence_length = static_cast<int32_t>(past_key_dims[2]);
+
+    if (past_key_dims[3] != head_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_key' dimension 3 should be same as head_size, got ",
+                             past_key_dims[3]);
+    }
+    if (past_value_dims[3] != head_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'past_value' dimension 3 should be same as head_size, got ",
+                             past_value_dims[3]);
+    }
+  } else if (past_key != nullptr || past_value != nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'past_key' and 'past_value' shall be both present or both absent.");
+  }
+
+  // Check seqlens_k tensor (holding past seqlen for token gen).
+  // Callers read one entry per batch index, so the tensor must hold exactly batch_size
+  // elements. Checking the element count rejects rank-1 tensors of the wrong length, avoids
+  // indexing the dims of a rank-0 tensor, and still accepts shapes such as (batch_size, 1).
+  if (seqlens_k->Shape().Size() != batch_size) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "seqlens_k must be shape (batch_size).");
+  }
+
+  // Set present sequence length and kv_share_buffer from input total_seqlen tensor
+  if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "total_sequence_length tensor must be of one element.");
+  }
+  int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
+  int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
+
+  int rotary_dim = 0;
+  if (cos_cache != nullptr && sin_cache != nullptr) {
+    const auto& cos_dims = cos_cache->Shape().GetDims();
+    const auto& sin_dims = sin_cache->Shape().GetDims();
+
+    // Both caches are indexed at [0] and [1] below, so their rank must be checked first.
+    if (cos_dims.size() != 2) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'cos_cache' is expected to have 2 dimensions, got ",
+                             cos_dims.size());
+    }
+    if (sin_dims.size() != 2) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'sin_cache' is expected to have 2 dimensions, got ",
+                             sin_dims.size());
+    }
+
+    if (head_size % 16 != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "head_size shall be a multiple of 16. Got head_size % 16 == ",
+                             head_size % 16);
+    }
+    if (cos_dims[0] < present_sequence_length) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "cos_cache dimension 0 should be of max_sequence_length.");
+    }
+    if (sin_dims[0] < present_sequence_length) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "sin_cache dimension 0 should be of max_sequence_length.");
+    }
+    if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+    }
+    if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+    }
+    if (cos_dims[1] != sin_dims[1]) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "cos_cache and sin_cache dimension 1 must be the same.");
+    }
+    // The caches store half of the rotary dimension per position.
+    rotary_dim = static_cast<int>(cos_dims[1] * 2);
+  } else if (cos_cache != nullptr || sin_cache != nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
+  }
+
+  // Any query with more (or fewer) than one token is treated as a prompt.
+  bool is_prompt = sequence_length != 1;
+
+  if (parameters != nullptr) {
+    GroupQueryAttentionParameters* output_parameters = static_cast<GroupQueryAttentionParameters*>(parameters);
+    output_parameters->batch_size = batch_size;
+    output_parameters->sequence_length = sequence_length;                  // sequence length of Q
+    output_parameters->seqlen_past_kv_cache = past_sequence_length;        // max sequence length of past kv tensors
+    output_parameters->seqlen_present_kv_cache = present_sequence_length;  // max sequence length of present kv tensors
+    output_parameters->hidden_size = q_hidden_size;
+    output_parameters->num_heads = num_heads;
+    output_parameters->head_size = head_size;
+    output_parameters->kv_hidden_size = kv_hidden_size;
+    output_parameters->kv_num_heads = kv_num_heads;
+    output_parameters->rotary_dim = rotary_dim;
+    output_parameters->is_packed_qkv = is_packed_qkv;
+    output_parameters->is_unidirectional = true;
+    output_parameters->is_prompt = is_prompt;
+    output_parameters->scale = scale;
+    output_parameters->qkv_format = qkv_format;
+    output_parameters->past_kv_format = past_kv_format;
+  }
+
+  return Status::OK();
+}
+
+// Overload that additionally enforces an upper bound on num_heads before running the
+// regular input validation.
+//   max_threads_per_block : when > 0, num_heads must not exceed it; values <= 0 disable the check.
+// All other parameters are forwarded unchanged to the 13-argument CheckInputs.
+// Returns INVALID_ARGUMENT if the bound is exceeded, otherwise the forwarded result.
+//
+// Declared inline because this definition lives in a header.
+inline Status CheckInputs(const Tensor* query,
+                          const Tensor* key,
+                          const Tensor* value,
+                          const Tensor* past_key,
+                          const Tensor* past_value,
+                          const Tensor* cos_cache,
+                          const Tensor* sin_cache,
+                          void* parameters,
+                          int num_heads,
+                          int kv_num_heads,
+                          const Tensor* seqlens_k,
+                          const Tensor* total_seqlen,
+                          float scale,
+                          int max_threads_per_block) {
+  if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
+  }
+
+  return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters,
+                     num_heads, kv_num_heads, seqlens_k, total_seqlen, scale);
+}
+
+// Copies the V heads of a packed QKV buffer into the matching position of another packed
+// buffer, in parallel.
+//   tp         : thread pool handed to TryParallelFor.
+//   parameters : supplies batch_size, sequence_length, num_heads, kv_num_heads and head_size.
+//   input      : pointer to the first V element of the source buffer.
+//   output     : pointer to the first V element of the destination buffer.
+// Both pointers are addressed with the same strides: head_size per token, sequence_length *
+// head_size per head, and (num_heads + 2 * kv_num_heads) heads per batch entry.
+// Always returns Status::OK().
+template <typename T>
+Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, const GroupQueryAttentionParameters& parameters,
+                          const T* input, T* output) {
+  const int batch_size = parameters.batch_size;
+  const int sequence_length = parameters.sequence_length;
+  const int n_heads = parameters.num_heads;
+  const int kv_n_heads = parameters.kv_num_heads;
+  const int head_size = parameters.head_size;
+
+  // Strides and the loop length are kept in ptrdiff_t so that products of batch, head and
+  // sequence counts cannot overflow int.
+  const std::ptrdiff_t seq_stride = head_size;
+  const std::ptrdiff_t head_stride = sequence_length * seq_stride;
+  const std::ptrdiff_t batch_stride = (n_heads + 2 * kv_n_heads) * head_stride;
+
+  // One work item per (batch, token, kv head); each item copies head_size elements.
+  const std::ptrdiff_t loop_len = static_cast<std::ptrdiff_t>(batch_size) * sequence_length * kv_n_heads;
+  const double cost = static_cast<double>(head_size);
+  concurrency::ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
+      // Decode the flat index: kv head varies fastest, then token, then batch.
+      const std::ptrdiff_t b = (ptr / kv_n_heads) / sequence_length;
+      const std::ptrdiff_t s = (ptr / kv_n_heads) % sequence_length;
+      const std::ptrdiff_t n = ptr % kv_n_heads;
+      const std::ptrdiff_t block_offset = b * batch_stride + s * seq_stride + n * head_stride;
+      const T* input_data = input + block_offset;
+      T* output_data = output + block_offset;
+      for (int i = 0; i < head_size; i++) {
+        output_data[i] = input_data[i];
+      }
+    }
+  });
+  return Status::OK();
+}
+
+} // namespace group_query_attention_helper
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index c4e4b4ec70..6e85be15d9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -4,15 +4,13 @@
#include "attention_cpu_base.h"
#include "multihead_attention.h"
#include "multihead_attention_helper.h"
+#include "attention_utils.h"
#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
-#include "core/framework/transpose_helper.h"
#include "core/graph/onnx_protobuf.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
-#include "core/providers/cpu/math/element_wise_ops.h"
-#include "core/providers/cpu/tensor/reshape_helper.h"
#include
#include
@@ -43,217 +41,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
}
-// Reshape Q/K/V from BxSxD to BxSxNxH
-Status Reshape_BSD_to_BSNH(Tensor* qkv,
- int batch_size,
- int sequence_length,
- int num_heads,
- int head_size) {
- std::vector reshape_dims({batch_size, sequence_length, num_heads, head_size});
- gsl::span reshape_dims_span{reshape_dims};
- TensorShape qkv_bsnh(reshape_dims_span);
- qkv->Reshape(qkv_bsnh);
- return Status::OK();
-}
-
-// Transpose Q/K/V from BxSxNxH to BxNxSxH
-Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
- OrtValue& qkv_transposed,
- concurrency::ThreadPool* tp = nullptr) {
- std::vector permutations({0, 2, 1, 3});
- gsl::span permutations_span{permutations};
- size_t from = 2, to = 1;
- SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp);
- return Status::OK();
-}
-
-// Add bias + transpose for each of Q/K/V
-template
-Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
- const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
- OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
- int bias_offset, // bias offset to enter qkv_bias
- int batch_size, // batch size
- int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
- int num_heads, // num heads
- int head_size, // head_size for Q/K, v_head_size for V
- int hidden_size, // hidden_size for Q/K, v_hidden_size for V
- OpKernelContext* context) {
- // Note: the comments below will refer to Q's dimensions for simplicity
- auto element_type = DataTypeImpl::GetType();
- constexpr size_t element_size = sizeof(T);
- ProcessBroadcastSpanFuncs add_funcs{
- [](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array();
- },
- [](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1();
- },
- [](BroadcastHelper& per_iter_bh) {
- per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1();
- }}; // For element-wise add
-
- // Allocate space for output of Q(BS, D) + bias(D)
- AllocatorPtr allocator;
- ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
- std::vector old_dims({batch_size, sequence_length, hidden_size});
- gsl::span old_dims_span{old_dims};
- TensorShape qkv_with_bias_shape(old_dims_span);
- OrtValue qkv_with_bias;
- Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias);
-
- // Get Q's bias from combined bias
- std::vector