### Description
Support GQA operator on CPU with FP32.



### Motivation and Context
Right now, models generated for CPU and GPU must be different. GQA CPU
allows these models to be the same.
This commit is contained in:
aciddelgado 2024-04-22 19:57:05 -07:00 committed by GitHub
parent c47a6ce70b
commit 94c69f55d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 3159 additions and 279 deletions

View file

@ -2455,6 +2455,9 @@ This version of the operator has been available since version 1 of the 'com.micr
Group Query Self/Cross Attention.
Supports different number of heads for q and kv. Only supports causal or local attention.
Supports rotary position embedding.
Supports k-v cache.
CPU EP supports fp32... CUDA EP supports fp16.
#### Version
@ -2514,7 +2517,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16), tensor(float)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain mask to int tensor.</dd>

View file

@ -481,6 +481,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|

View file

@ -68,6 +68,7 @@ class AttentionBase {
const Tensor* past_seq_len = nullptr) const;
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
bool is_unidirectional_; // whether every token can only attend to previous tokens.
std::vector<int64_t> qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute.
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.

View file

@ -153,6 +153,49 @@ void PrepareMask(const int32_t* mask_index,
}
}
// Applies causal mask and past seqlens right pad to the mask_data buffer.
template <typename T>
void PrepareMaskGQA(T* mask_data,
int batch_size,
int sequence_length,
int buffer_sequence_length,
int local_window_size,
const int32_t* seqlens_k) {
// mask_data has been filled with 0, and its shape is BxSxT
T* p_mask = mask_data;
// TODO: parallelize this
for (int b_i = 0; b_i < batch_size; b_i++) {
if (sequence_length > 1) {
// Apply causal/local mask for prompt case.
for (int s_i = 0; s_i < sequence_length; s_i++) {
for (int m_i = s_i + 1; m_i < buffer_sequence_length; m_i++) {
p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits<T>::lowest();
}
// Apply local mask.
if (local_window_size > 0) {
for (int m_i = 0; m_i < s_i - local_window_size; m_i++) {
p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits<T>::lowest();
}
}
}
} else if (sequence_length == 1) {
// Apply right padding to mask for token gen case.
int total_seqlen = seqlens_k[b_i] + 1;
for (int m_i = total_seqlen; m_i < buffer_sequence_length; m_i++) {
p_mask[m_i] = std::numeric_limits<T>::lowest();
}
// Apply local mask.
if (local_window_size > 0) {
for (int m_i = 0; m_i < total_seqlen - local_window_size - 1; m_i++) {
p_mask[m_i] = std::numeric_limits<T>::lowest();
}
}
}
ptrdiff_t mask_to_advance = SafeInt<ptrdiff_t>(sequence_length) * buffer_sequence_length;
p_mask += mask_to_advance;
}
}
// Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH
// Returns a pointer to the start of present state chunk.
template <typename T>
@ -175,5 +218,32 @@ T* ConcatStateChunk(const T* past,
return start;
}
// GQA version of ConcatStateChunk
template <typename T>
T* ConcatStateChunkGQA(const T* past,
const T* chunk,
T* present,
size_t present_buff_chunk_length,
size_t past_buff_chunk_length,
size_t past_chunk_length,
size_t new_chunk_length,
bool is_prompt,
bool past_present_share_buffer,
std::ptrdiff_t i) {
T* start = present + i * present_buff_chunk_length;
T* p = start;
if (!is_prompt) {
if (!past_present_share_buffer) {
const T* src_past = past + i * past_buff_chunk_length;
memcpy(p, src_past, past_chunk_length * sizeof(T));
}
p += past_chunk_length;
}
memcpy(p, chunk, new_chunk_length * sizeof(T));
return start;
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,246 @@
#include "attention_utils.h"
#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/transpose_helper.h"
#include "core/providers/cpu/tensor/reshape_helper.h"
#include "core/providers/cpu/math/element_wise_ops.h"
using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
namespace contrib {
// Reshape Q/K/V from BxSxD to BxSxNxH
inline Status Reshape_BSD_to_BSNH(Tensor* qkv,
int batch_size,
int sequence_length,
int num_heads,
int head_size) {
qkv->Reshape(TensorShape({batch_size, sequence_length, num_heads, head_size}));
return Status::OK();
}
// Transpose Q/K/V from BxSxNxH to BxNxSxH
inline Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
OrtValue& qkv_transposed,
concurrency::ThreadPool* tp) {
std::vector<size_t> permutations({0, 2, 1, 3});
gsl::span<const size_t> permutations_span{permutations};
size_t from = 2, to = 1;
SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
return Status::OK();
}
// Add bias + transpose for each of Q/K/V
template <typename T>
Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
int bias_offset, // bias offset to enter qkv_bias
int batch_size, // batch size
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
int num_heads, // num heads
int head_size, // head_size for Q/K, v_head_size for V
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
OpKernelContext* context) {
// Note: the comments below will refer to Q's dimensions for simplicity
auto element_type = DataTypeImpl::GetType<T>();
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
}}; // For element-wise add
// Allocate space for output of Q(BS, D) + bias(D)
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
std::vector<int64_t> old_dims({batch_size, sequence_length, hidden_size});
gsl::span<const int64_t> old_dims_span{old_dims};
TensorShape qkv_with_bias_shape(old_dims_span);
OrtValue qkv_with_bias;
Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias);
// Get Q's bias from combined bias
std::vector<int64_t> bias_dims({hidden_size});
gsl::span<const int64_t> bias_dims_span{bias_dims};
TensorShape bias_shape(bias_dims_span);
OrtValue bias;
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, hidden_size * element_size);
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
{
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
size_t span_size = input_broadcaster.GetSpanSize();
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
void* user_data = nullptr;
const int loop_len = static_cast<int>(output_size / span_size);
double unit_cost = 1.0f;
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
unit_cost * span_size};
auto tp = context->GetOperatorThreadPool();
ThreadPool::TryParallelFor(tp, loop_len, cost,
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
std::ptrdiff_t last_span) {
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
segment_input_broadcaster.AdvanceBy(first_span * span_size);
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
first_span * span_size, last_span * span_size);
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
BroadcastLooper(segment_helper, add_funcs);
});
}
// Reshape Q from BxSxD to BxSxNxH
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
// Transpose Q from BxSxNxH to BxNxSxH
auto tp = context->GetOperatorThreadPool();
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
return Status::OK();
}
// Add bias + reshape for each of Q/K/V
// This is used in decoder_with_past when the sequence length is 1
template <typename T>
Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
int bias_offset, // bias offset to enter qkv_bias
int batch_size, // batch size
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
int num_heads, // num heads
int head_size, // head_size for Q/K, v_head_size for V
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
OpKernelContext* context) {
// Note: the comments below will refer to Q's dimensions for simplicity
auto element_type = DataTypeImpl::GetType<T>();
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
}}; // For element-wise add
// Get Q's bias from combined bias
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
std::vector<int64_t> bias_dims({hidden_size});
gsl::span<const int64_t> bias_dims_span{bias_dims};
TensorShape bias_shape(bias_dims_span);
OrtValue bias;
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
auto num_bias_elements = SafeInt<size_t>(hidden_size) * element_size;
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, num_bias_elements);
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
{
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
size_t span_size = input_broadcaster.GetSpanSize();
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
void* user_data = nullptr;
const int loop_len = static_cast<int>(output_size / span_size);
double unit_cost = 1.0f;
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
unit_cost * span_size};
auto tp = context->GetOperatorThreadPool();
ThreadPool::TryParallelFor(tp, loop_len, cost,
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
std::ptrdiff_t last_span) {
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
segment_input_broadcaster.AdvanceBy(first_span * span_size);
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
first_span * span_size, last_span * span_size);
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
BroadcastLooper(segment_helper, add_funcs);
});
}
// Reshape Q from BxSxD to BxNxSxH
qkv_with_bias.GetMutable<Tensor>()->Reshape(TensorShape({batch_size, num_heads, sequence_length, head_size}));
return Status::OK();
}
template <typename T>
Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) {
auto element_type = DataTypeImpl::GetType<T>();
std::vector<int64_t> new_dims({batch_size, num_heads, sequence_length, head_size});
gsl::span<const int64_t> new_dims_span{new_dims};
TensorShape v_BNLH(new_dims_span);
Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);
if (bias == nullptr) {
std::unique_ptr<Tensor> reshaped;
if (in->Shape().GetDims().size() == 3) {
reshaped = std::make_unique<Tensor>(in->DataType(), in->Shape(), const_cast<void*>(in->DataRaw()), in->Location());
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
}
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));
} else {
const auto* qkv_bias = bias->Data<T>();
if (sequence_length == 1) {
ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
} else {
ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
}
}
return Status::OK();
};
template Status MaybeTransposeToBNSHAndAddBias<float>(OpKernelContext* context, AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
template <typename T>
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out) {
auto element_type = DataTypeImpl::GetType<T>();
std::vector<int64_t> new_dims({batch_size, num_heads, sequence_length, head_size});
gsl::span<const int64_t> new_dims_span{new_dims};
TensorShape v_BNLH(new_dims_span);
Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);
std::unique_ptr<Tensor> reshaped;
if (in->Shape().GetDims().size() == 3) {
reshaped = std::make_unique<Tensor>(in->DataType(), in->Shape(), const_cast<void*>(in->DataRaw()), in->Location());
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
}
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));
return Status::OK();
};
template Status MaybeTransposeToBNSH<float>(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,61 @@
#pragma once
#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/transpose_helper.h"
#include "core/providers/cpu/tensor/reshape_helper.h"
#include "core/providers/cpu/math/element_wise_ops.h"
namespace onnxruntime {
namespace contrib {
// Reshape Q/K/V from BxSxD to BxSxNxH
Status Reshape_BSD_to_BSNH(Tensor* qkv,
int batch_size,
int sequence_length,
int num_heads,
int head_size);
// Transpose Q/K/V from BxSxNxH to BxNxSxH
Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
OrtValue& qkv_transposed,
concurrency::ThreadPool* tp = nullptr);
// Add bias + transpose for each of Q/K/V
template <typename T>
Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
int bias_offset, // bias offset to enter qkv_bias
int batch_size, // batch size
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
int num_heads, // num heads
int head_size, // head_size for Q/K, v_head_size for V
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
OpKernelContext* context); // OpKernelContext
// Add bias + reshape for each of Q/K/V
// This is used in decoder_with_past when the sequence length is 1
template <typename T>
Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
int bias_offset, // bias offset to enter qkv_bias
int batch_size, // batch size
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
int num_heads, // num heads
int head_size, // head_size for Q/K, v_head_size for V
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
OpKernelContext* context); // OpKernelContext
template <typename T>
Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
template <typename T>
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,276 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "attention_base.h"
#include "attention_helper.h"
#include "core/common/common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace contrib {
class GQAAttentionBase : public AttentionBase {
protected:
GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size)
: AttentionBase(info, require_same_hidden_size) {}
int local_window_size_;
bool do_rotary_;
bool rotary_interleaved_;
template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxN_kvxSxH
const T* V, // V data with shape BxN_kvxSxH
const Tensor* past_key, // past K input tensor (if not using past state)
const Tensor* past_value, // past V input tensor (if not using past state)
Tensor* output, // output tensor
Tensor* present_key, // present K output tensor (if separating present KV)
Tensor* present_value, // present V output tensor (if separating present KV)
const Tensor* seqlens_k, // past sequence lengths tensor
GroupQueryAttentionParameters& parameters, // attention parameters
AllocatorPtr allocator, // allocator for temporary tensors
OpKernelContext* context) const {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int head_size = parameters.head_size;
const int hidden_size = parameters.hidden_size;
const bool packed_qkv = parameters.is_packed_qkv;
auto* tp = context->GetOperatorThreadPool();
int seqlen_past_kv_cache = 0;
if (past_key != nullptr && past_value != nullptr) {
seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
}
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
// Compute the attention score.
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
void* mask_data = nullptr;
size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * seqlen_present_kv_cache * sizeof(T);
mask_data = allocator->Alloc(mask_data_bytes);
memset(mask_data, 0, mask_data_bytes);
BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k,
seqlens_k->Data<int32_t>(), static_cast<T*>(mask_data),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache,
head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp);
// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
v, seqlens_k->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
past_present_share_buffer, packed_qkv, tp);
return Status::OK();
}
private:
// Helper function to compute the attention probs. It does 2 things:
// attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) +
// 1 x mask_data(B, N, S, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const int32_t* seqlens_k, // past sequence lengths tensor
T* mask_data, // buffer for mask data.
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int past_buffer_sequence_length, // sequence length of past state
int present_buffer_sequence_length, // sequence length of present state
int head_size, // head size of self-attention
const T* past_key, // past key only (if not using past state)
T* present_key, // present key only (if not using present state)
bool past_present_share_buffer, // whether present key and value share the same buffer
bool packed_qkv, // whether Q, K, V are packed
ThreadPool* tp) const { // thread pool
const bool is_prompt = sequence_length != 1;
const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // L x H
const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size; // L x H
const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size; // T x H
PrepareMaskGQA(mask_data, batch_size, sequence_length, present_buffer_sequence_length, local_window_size_, seqlens_k);
const int loop_len = batch_size * num_heads_;
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
// TODO: cost might differ for gqa because of right padding and total_sequence_length being sequence dependent
TensorOpCost unit_cost;
const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * present_buffer_sequence_length * sizeof(T);
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
unit_cost.bytes_loaded = static_cast<double>((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);
unit_cost.bytes_loaded += static_cast<double>(probs_matrix_bytes);
unit_cost.bytes_stored += static_cast<double>(probs_matrix_bytes);
if (present_key) {
double bytes_to_copy_key = static_cast<double>(sizeof(T) * present_buff_chunk_length);
unit_cost.bytes_loaded += bytes_to_copy_key;
unit_cost.bytes_stored += bytes_to_copy_key;
}
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i) / num_heads_;
const int head_index = static_cast<int>(i) % num_heads_;
const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
const int output_offset = static_cast<int>(i) * sequence_length * present_buffer_sequence_length;
const int mask_offset = batch_index * sequence_length * present_buffer_sequence_length;
T* output = attention_probs + output_offset;
// Broadcast mask data: (Bx)SxT -> (BxNx)SxT
// TODO: mask after present_sequence_length
memcpy(output,
mask_data + mask_offset,
probs_matrix_bytes);
const T* k;
if (packed_qkv) {
k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
} else {
k = K + kv_input_chunk_length * (i / kv_num_heads_factor);
}
if (nullptr != present_key) {
k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
i / kv_num_heads_factor);
}
// TODO: CblasTrans stuff what do?
// Compute Q*K' + AttentionMask
// original transposed each iteration
// A: Q (B x N x) S x H (B x N x) S x H S x H
// B: K' (B x N x) T x H (B x N x) H x T H x T
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
const T* q;
if (packed_qkv) {
q = Q + packed_batch_stride * batch_index + q_input_chunk_length * head_index;
} else {
q = Q + q_input_chunk_length * i;
}
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, present_buffer_sequence_length, head_size, alpha,
q, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr);
}
});
// attention_probs(B, N, S, T) = Softmax(attention_probs)
const int N = batch_size * num_heads_ * sequence_length;
const int D = present_buffer_sequence_length;
ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
}
template <typename T>
void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
T* tmp_buffer, // buffer for temp use with size is BxNxSxH
const T* attention_probs, // Attention probs with size BxNxSxT
const T* V, // V value with size BxN_kvxSxH
const int32_t* seqlens_k, // past sequence lengths tensor
int batch_size, // batch size
int sequence_length, // sequence length
int past_buffer_sequence_length, // sequence length in past state
int present_buffer_sequence_length, // sequence length in past state
int head_size, // head size of Q, K, V
int hidden_size, // hidden size of Output
const T* past_value, // past value only (if not using past state)
T* present_value, // present value only (if not using present state)
bool past_present_share_buffer, // whether present key and value share the same buffer
bool packed_qkv, // whether Q, K, V are packed
ThreadPool* tp) const {
const bool is_prompt = sequence_length != 1;
const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
// TODO: what is with these being ptrdiff_t?
const ptrdiff_t q_input_chunk_length = SafeInt<ptrdiff_t>(sequence_length) * head_size; // S x H
const ptrdiff_t kv_input_chunk_length = SafeInt<ptrdiff_t>(sequence_length) * head_size; // L x H
const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size; // L x H
const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size; // T x H
// The cost of Gemm
TensorOpCost unit_cost;
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
unit_cost.bytes_loaded = static_cast<double>((sequence_length + head_size) * present_buffer_sequence_length * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(sequence_length * head_size * sizeof(T));
if (present_value) {
double bytes_to_copy_value = static_cast<double>(present_buff_chunk_length * sizeof(T));
unit_cost.bytes_loaded += bytes_to_copy_value;
unit_cost.bytes_stored += bytes_to_copy_value;
}
const size_t bytes_to_copy_trans = SafeInt<size_t>(head_size) * sizeof(T);
double bytes_to_copy_trans_all = static_cast<double>(sequence_length * bytes_to_copy_trans);
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;
ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i / num_heads_);
const int head_index = static_cast<int>(i % num_heads_);
const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
const T* v;
if (packed_qkv) {
v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
} else {
v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
}
if (nullptr != present_value) {
v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
i / kv_num_heads_factor);
}
T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
math::MatMul<T>(sequence_length, head_size, present_buffer_sequence_length,
attention_probs + attention_probs_offset,
v, current_tmp_data, nullptr);
// Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
T* src = current_tmp_data;
ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * head_size;
T* dest = output + dest_offset;
for (int j = 0; j < sequence_length; j++) {
memcpy(dest, src, bytes_to_copy_trans);
src += head_size;
dest += hidden_size;
}
}
});
}
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,192 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "group_query_attention.h"
#include "group_query_attention_helper.h"
#include "attention_utils.h"
#include "rotary_embedding.h"
#include "rotary_embedding_helper.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
#include <unsupported/Eigen/SpecialFunctions>
#include <vector>
using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
namespace contrib {
// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
GroupQueryAttention,
kMSDomain,
1,
float,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()),
GroupQueryAttention<float>);
template <typename T>
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) {
int64_t num_heads = 0;
int64_t kv_num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
}
template <typename T>
Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const Tensor* query = context->Input<Tensor>(0);
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
const Tensor* past_key = context->Input<Tensor>(3);
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* seqlens_k = context->Input<Tensor>(5);
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);
GroupQueryAttentionParameters parameters = {};
constexpr float scale = 1.0f;
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
value,
past_key,
past_value,
cos_cache,
sin_cache,
&parameters,
num_heads_,
kv_num_heads_,
seqlens_k,
total_seqlen,
scale));
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int present_kv_seqlen = parameters.seqlen_present_kv_cache;
int head_size = parameters.head_size;
int q_hidden_size = parameters.hidden_size;
const bool packed_qkv = parameters.is_packed_qkv;
std::vector<int64_t> output_shape(3);
output_shape[0] = static_cast<int64_t>(batch_size);
output_shape[1] = static_cast<int64_t>(sequence_length);
output_shape[2] = static_cast<int64_t>(q_hidden_size);
Tensor* output = context->Output(0, output_shape);
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(head_size)});
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(head_size)});
Tensor* present_k = context->Output(1, present_k_shape);
Tensor* present_v = context->Output(2, present_v_shape);
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto element_type = DataTypeImpl::GetType<T>();
OrtValue Q;
OrtValue K;
OrtValue V;
if (packed_qkv) {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
} else if (sequence_length > 1) {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
} else {
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*query)), Q);
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*key)), K);
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*value)), V);
}
if (do_rotary_) {
rotary_embedding_helper::RotaryParameters rotary_params = {};
rotary_params.batch_size = batch_size;
rotary_params.sequence_length = sequence_length;
rotary_params.hidden_size = q_hidden_size;
rotary_params.head_size = head_size;
rotary_params.rotary_embedding_dim = parameters.rotary_dim;
rotary_params.num_heads = num_heads_;
rotary_params.max_sequence_length = sequence_length; // unused
rotary_params.seq_stride = head_size;
rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride;
rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
rotary_params.transposed = true;
auto* tp = context->GetOperatorThreadPool();
std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
if (sequence_length == 1) {
for (int b = 0; b < batch_size; b++) {
pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
}
} else {
pos_ids[0] = static_cast<int64_t>(0);
}
const T* q_input;
const T* k_input;
T* q_rotary;
T* k_rotary;
if (packed_qkv) {
OrtValue RotaryQKV;
Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
q_input = Q.Get<Tensor>().Data<T>();
k_input = q_input + num_heads_ * sequence_length * head_size;
q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
Q = RotaryQKV;
} else {
OrtValue RotaryQ;
Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ);
OrtValue RotaryK;
Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK);
q_input = Q.Get<Tensor>().Data<T>();
k_input = K.Get<Tensor>().Data<T>();
q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>();
k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>();
Q = RotaryQ;
K = RotaryK;
}
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
pos_ids.data(), cos_cache->Data<T>(),
sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
rotary_params.num_heads = kv_num_heads_;
rotary_params.hidden_size = parameters.kv_hidden_size;
if (!packed_qkv) {
rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride;
}
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
pos_ids.data(), cos_cache->Data<T>(),
sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
if (packed_qkv) {
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV<T>(tp, parameters, v_input, v_rotary));
}
}
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
// Compute the attention score and apply the score to V
return ApplyAttention(Q.Get<Tensor>().Data<T>(), packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(),
packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), past_key, past_value, output, present_k, present_v,
seqlens_k, parameters, allocator, context);
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "gqa_attention_base.h"
namespace onnxruntime {
namespace contrib {
template <typename T>
class GroupQueryAttention final : public OpKernel, public GQAAttentionBase {
public:
GroupQueryAttention(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,299 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace group_query_attention_helper {
Status CheckInputs(const Tensor* query,
const Tensor* key,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
const Tensor* seqlens_k,
const Tensor* total_seqlen,
float scale) {
// Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
// past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
// past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
// no packing for q/k/v:
// query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv))
// key (K) : (B, S, D_kv) or nullptr
// value (V) : (B, S, D_kv) or nullptr
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = Q_K_V_BNSH;
const bool is_packed_qkv = key == nullptr;
const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
query_dims.size());
}
int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int q_hidden_size = static_cast<int>(query_dims[2]);
int head_size = 0;
if (num_heads % kv_num_heads != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
num_heads % kv_num_heads);
}
int kv_hidden_size = 0;
// Check key and value when not packed
if (!is_packed_qkv) {
head_size = static_cast<int>(q_hidden_size) / num_heads;
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
}
if (value == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
}
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
} else if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
} else if (query_dims[1] != key_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
}
kv_hidden_size = static_cast<int>(key_dims[2]);
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
value_dims.size());
} else if (query_dims[0] != value_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 0 (batch size)");
} else if (query_dims[1] != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same dim 1 (sequence length)");
} else if (value_dims[2] != kv_hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
}
} else {
// Check packed qkv
head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size % 8 == ",
head_size % 8);
}
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
}
q_hidden_size = head_size * num_heads;
kv_hidden_size = head_size * kv_num_heads;
}
// Check past-present KV
int32_t past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
const auto& past_key_dims = past_key->Shape().GetDims();
const auto& past_value_dims = past_value->Shape().GetDims();
if (past_key_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' is expected to have 4 dimensions, got ",
past_key_dims.size());
}
if (past_value_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' is expected to have 4 dimensions, got ",
past_value_dims.size());
}
if (past_key_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 0 should be batch_size, got ",
past_key_dims[0]);
}
if (past_value_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 0 should be batch_size, got ",
past_value_dims[0]);
}
if (past_key_dims[2] != past_value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence"
"length or past sequence length), got ",
past_key_dims[1]);
}
if (past_key_dims[1] != kv_num_heads) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' shall have kv_num_heads");
}
if (past_value_dims[1] != kv_num_heads) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' shall have kv_num_heads");
}
// We assume all sequence in past kv are right-padded to max or past sequence length
past_sequence_length = static_cast<int>(past_key_dims[2]);
if (past_key_dims[3] != head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
past_key_dims[3]);
}
if (past_value_dims[3] != head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 3 should be same as head_size, got ",
past_value_dims[3]);
}
} else if (past_key != nullptr || past_value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}
// Check seqlens_k tensor (holding past seqlen for token gen)
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (batch_size).");
}
// Set present sequence length and kv_share_buffer from input total_seqlen tensor
if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"total_sequence_length tensor must be of one element.");
}
int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
int rotary_dim = 0;
if (cos_cache != nullptr && sin_cache != nullptr) {
const auto& cos_dims = cos_cache->Shape().GetDims();
const auto& sin_dims = sin_cache->Shape().GetDims();
if (head_size % 16 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size shall be a multiple of 16. Got head_size % 16 == ",
head_size % 16);
}
if (cos_dims[0] < present_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 0 should be of max_sequence_length.");
}
if (sin_dims[0] < present_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 0 should be of max_sequence_length.");
}
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
if (cos_dims[1] != sin_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache and sin_cache dimension 1 must be the same.");
}
rotary_dim = static_cast<int>(cos_dims[1] * 2);
} else if (cos_cache != nullptr || sin_cache != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}
bool is_prompt = sequence_length != 1;
if (parameters != nullptr) {
GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length; // sequence length of Q
output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors
output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors
output_parameters->hidden_size = q_hidden_size;
output_parameters->num_heads = num_heads;
output_parameters->head_size = head_size;
output_parameters->kv_hidden_size = kv_hidden_size;
output_parameters->kv_num_heads = kv_num_heads;
output_parameters->rotary_dim = rotary_dim;
output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_unidirectional = true;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
output_parameters->qkv_format = qkv_format;
output_parameters->past_kv_format = past_kv_format;
}
return Status::OK();
}
Status CheckInputs(const Tensor* query,
const Tensor* key,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
const Tensor* seqlens_k,
const Tensor* total_seqlen,
float scale,
int max_threads_per_block) {
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale);
}
template <typename T>
Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, GroupQueryAttentionParameters parameters, const T* input,
T* output) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
int n_heads = parameters.num_heads;
const int kv_n_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
int seq_stride = head_size;
int head_stride = sequence_length * seq_stride;
int batch_stride = (n_heads + 2 * kv_n_heads) * head_stride;
const int loop_len = batch_size * sequence_length * kv_n_heads;
const double cost = static_cast<double>(head_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / kv_n_heads) / sequence_length);
const int s = static_cast<int>((ptr / kv_n_heads) % sequence_length);
const int n = static_cast<int>(ptr % kv_n_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;
for (int i = 0; i < head_size; i++) {
output_data[i] = input_data[i];
}
}
});
return Status::OK();
}
} // namespace group_query_attention_helper
} // namespace contrib
} // namespace onnxruntime

View file

@ -4,15 +4,13 @@
#include "attention_cpu_base.h"
#include "multihead_attention.h"
#include "multihead_attention_helper.h"
#include "attention_utils.h"
#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/transpose_helper.h"
#include "core/graph/onnx_protobuf.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
#include "core/providers/cpu/math/element_wise_ops.h"
#include "core/providers/cpu/tensor/reshape_helper.h"
#include <unsupported/Eigen/SpecialFunctions>
#include <vector>
@ -43,217 +41,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
}
// Reshape Q/K/V from BxSxD to BxSxNxH
Status Reshape_BSD_to_BSNH(Tensor* qkv,
int batch_size,
int sequence_length,
int num_heads,
int head_size) {
std::vector<int64_t> reshape_dims({batch_size, sequence_length, num_heads, head_size});
gsl::span<const int64_t> reshape_dims_span{reshape_dims};
TensorShape qkv_bsnh(reshape_dims_span);
qkv->Reshape(qkv_bsnh);
return Status::OK();
}
// Transpose Q/K/V from BxSxNxH to BxNxSxH
Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
OrtValue& qkv_transposed,
concurrency::ThreadPool* tp = nullptr) {
std::vector<size_t> permutations({0, 2, 1, 3});
gsl::span<const size_t> permutations_span{permutations};
size_t from = 2, to = 1;
SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
return Status::OK();
}
// Add bias + transpose for each of Q/K/V
template <typename T>
Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
int bias_offset, // bias offset to enter qkv_bias
int batch_size, // batch size
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
int num_heads, // num heads
int head_size, // head_size for Q/K, v_head_size for V
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
OpKernelContext* context) {
// Note: the comments below will refer to Q's dimensions for simplicity
auto element_type = DataTypeImpl::GetType<T>();
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
}}; // For element-wise add
// Allocate space for output of Q(BS, D) + bias(D)
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
std::vector<int64_t> old_dims({batch_size, sequence_length, hidden_size});
gsl::span<const int64_t> old_dims_span{old_dims};
TensorShape qkv_with_bias_shape(old_dims_span);
OrtValue qkv_with_bias;
Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias);
// Get Q's bias from combined bias
std::vector<int64_t> bias_dims({hidden_size});
gsl::span<const int64_t> bias_dims_span{bias_dims};
TensorShape bias_shape(bias_dims_span);
OrtValue bias;
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, hidden_size * element_size);
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
{
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
size_t span_size = input_broadcaster.GetSpanSize();
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
void* user_data = nullptr;
const int loop_len = static_cast<int>(output_size / span_size);
double unit_cost = 1.0f;
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
unit_cost * span_size};
auto tp = context->GetOperatorThreadPool();
ThreadPool::TryParallelFor(tp, loop_len, cost,
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
std::ptrdiff_t last_span) {
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
segment_input_broadcaster.AdvanceBy(first_span * span_size);
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
first_span * span_size, last_span * span_size);
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
BroadcastLooper(segment_helper, add_funcs);
});
}
// Reshape Q from BxSxD to BxSxNxH
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
// Transpose Q from BxSxNxH to BxNxSxH
auto tp = context->GetOperatorThreadPool();
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
return Status::OK();
}
// Add bias + reshape for each of Q/K/V
// This is used in decoder_with_past when the sequence length is 1
template <typename T>
Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
int bias_offset, // bias offset to enter qkv_bias
int batch_size, // batch size
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
int num_heads, // num heads
int head_size, // head_size for Q/K, v_head_size for V
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
OpKernelContext* context) {
// Note: the comments below will refer to Q's dimensions for simplicity
auto element_type = DataTypeImpl::GetType<T>();
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
}}; // For element-wise add
// Get Q's bias from combined bias
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
std::vector<int64_t> bias_dims({hidden_size});
gsl::span<const int64_t> bias_dims_span{bias_dims};
TensorShape bias_shape(bias_dims_span);
OrtValue bias;
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
auto num_bias_elements = SafeInt<size_t>(hidden_size) * element_size;
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, num_bias_elements);
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
{
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
size_t span_size = input_broadcaster.GetSpanSize();
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
void* user_data = nullptr;
const int loop_len = static_cast<int>(output_size / span_size);
double unit_cost = 1.0f;
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
unit_cost * span_size};
auto tp = context->GetOperatorThreadPool();
ThreadPool::TryParallelFor(tp, loop_len, cost,
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
std::ptrdiff_t last_span) {
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
segment_input_broadcaster.AdvanceBy(first_span * span_size);
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
first_span * span_size, last_span * span_size);
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
BroadcastLooper(segment_helper, add_funcs);
});
}
// Reshape Q from BxSxD to BxNxSxH
std::vector<int64_t> reshape_dims({batch_size, num_heads, sequence_length, head_size});
gsl::span<const int64_t> reshape_dims_span{reshape_dims};
TensorShape qkv_final_dims(reshape_dims_span);
qkv_with_bias.GetMutable<Tensor>()->Reshape(qkv_final_dims);
return Status::OK();
}
template <typename T>
Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) {
auto element_type = DataTypeImpl::GetType<T>();
std::vector<int64_t> new_dims({batch_size, num_heads, sequence_length, head_size});
gsl::span<const int64_t> new_dims_span{new_dims};
TensorShape v_BNLH(new_dims_span);
Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);
if (bias == nullptr) {
std::unique_ptr<Tensor> reshaped;
if (in->Shape().GetDims().size() == 3) {
reshaped = std::make_unique<Tensor>(in->DataType(), in->Shape(), const_cast<void*>(in->DataRaw()), in->Location());
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
}
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));
} else {
const auto* qkv_bias = bias->Data<T>();
if (sequence_length == 1) {
ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
} else {
ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
}
}
return Status::OK();
};
template <typename T>
Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
const Tensor* query = context->Input<Tensor>(0);

View file

@ -36,6 +36,71 @@ RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
}
}
// TODO: rotary embedding in place
template <typename T>
Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input,
const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output,
bool interleaved) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int n_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int head_stride = parameters.head_stride;
const int seq_stride = parameters.seq_stride;
const int batch_stride = parameters.batch_stride;
const int position_ids_format = parameters.position_ids_format;
const int rotary_emb_dim = parameters.rotary_embedding_dim;
const int half_rotary_emb_dim = rotary_emb_dim / 2;
const int loop_len = batch_size * sequence_length * n_heads;
const double cost = static_cast<double>(rotary_emb_dim);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
const int n = static_cast<int>(ptr % n_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;
// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(position_ids[0]) + s
: static_cast<int>(position_ids[b * sequence_length + s]);
const int cache_offset = position_id * half_rotary_emb_dim;
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
int cache_idx = 0;
T sign = 0;
int j = 0;
for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
for (int i = rotary_emb_dim; i < head_size; i++) {
output_data[i] = input_data[i];
}
}
});
return Status::OK();
}
template Status RunRotaryEmbedding<float>(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input,
const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output,
bool interleaved);
template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
@ -65,72 +130,12 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const T* sin_cache_data = sin_cache->Data<T>();
T* output_dest = output->MutableData<T>();
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int n_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int rotary_emb_dim = parameters.rotary_embedding_dim;
const int half_rotary_emb_dim = rotary_emb_dim / 2;
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
int seq_stride = n_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
// Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = n_heads * head_stride;
}
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();
const int loop_len = batch_size * sequence_length * n_heads;
const double cost = static_cast<double>(rotary_emb_dim);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
const int n = static_cast<int>(ptr % n_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;
// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(pos_ids_data[0]) + s
: static_cast<int>(pos_ids_data[b * sequence_length + s]);
const int cache_offset = position_id * half_rotary_emb_dim;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;
int cache_idx = 0;
T sign = 0;
int j = 0;
for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_rotary_emb_dim;
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_rotary_emb_dim;
sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
for (int i = rotary_emb_dim; i < head_size; i++) {
output_data[i] = input_data[i];
}
}
});
return Status::OK();
return RunRotaryEmbedding<T>(tp, parameters, input_src, pos_ids_data, cos_cache_data, sin_cache_data, output_dest,
interleaved);
}
} // namespace contrib

View file

@ -4,10 +4,16 @@
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "rotary_embedding_helper.h"
namespace onnxruntime {
namespace contrib {
template <typename T>
Status RunRotaryEmbedding(onnxruntime::concurrency::ThreadPool* tp, rotary_embedding_helper::RotaryParameters parameters, const T* input,
const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output,
bool interleaved);
template <typename T>
class RotaryEmbedding final : public OpKernel {
public:

View file

@ -18,6 +18,9 @@ struct RotaryParameters {
int rotary_embedding_dim; // Rotary embedding dimension.
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int head_stride; // Head stride
int seq_stride; // Sequence stride
int batch_stride; // Batch stride
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};
@ -116,6 +119,23 @@ Status CheckInputs(const T* input,
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
}
num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
// Calculate stride values
int head_stride;
int seq_stride;
int batch_stride;
if (transposed) {
// Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
} else {
// Default input tensor shape is [batch, seq_len, hidden_size]
head_stride = head_size;
seq_stride = num_heads * head_stride;
batch_stride = sequence_length * seq_stride;
}
// Set rotary parameters
if (parameters != nullptr) {
RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters);
@ -123,8 +143,11 @@ Status CheckInputs(const T* input,
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->head_stride = head_stride;
output_parameters->seq_stride = seq_stride;
output_parameters->batch_stride = batch_stride;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;

View file

@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
@ -257,6 +258,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,

View file

@ -1009,6 +1009,9 @@ constexpr const char* GroupQueryAttention_ver1_doc = R"DOC(
Group Query Self/Cross Attention.
Supports different number of heads for q and kv. Only supports causal or local attention.
Supports rotary position embedding.
Supports k-v cache.
CPU EP supports fp32... CUDA EP supports fp16.
)DOC";
ONNX_MS_OPERATOR_SET_SCHEMA(
@ -1094,7 +1097,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
"kv_sequence_length.",
"T")
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
GroupQueryAttentionTypeAndShapeInference(ctx, 3);

File diff suppressed because it is too large Load diff