mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
GQA 4 CPU (#20299)
### Description Support GQA operator on CPU with FP32. ### Motivation and Context Right now, models generated for CPU and GPU must be different. GQA CPU allows these models to be the same.
This commit is contained in:
parent
c47a6ce70b
commit
94c69f55d4
18 changed files with 3159 additions and 279 deletions
|
|
@ -2455,6 +2455,9 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
Group Query Self/Cross Attention.
|
||||
|
||||
Supports different number of heads for q and kv. Only supports causal or local attention.
|
||||
Supports rotary position embedding.
|
||||
Supports k-v cache.
|
||||
The CPU EP supports fp32; the CUDA EP supports fp16.
|
||||
|
||||
#### Version
|
||||
|
||||
|
|
@ -2514,7 +2517,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
|
||||
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16), tensor(float)</dt>
|
||||
<dd>Constrain input and output to float tensors.</dd>
|
||||
<dt><tt>M</tt> : tensor(int32)</dt>
|
||||
<dd>Constrain mask to int tensor.</dd>
|
||||
|
|
|
|||
|
|
@ -481,6 +481,7 @@ Do not modify directly.*
|
|||
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|
||||
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|
||||
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float)|
|
||||
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|
||||
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class AttentionBase {
|
|||
const Tensor* past_seq_len = nullptr) const;
|
||||
|
||||
int num_heads_; // number of attention heads
|
||||
int kv_num_heads_; // different for k and v for group query attention
|
||||
bool is_unidirectional_; // whether every token can only attend to previous tokens.
|
||||
std::vector<int64_t> qkv_hidden_sizes_; // Q, K, V hidden sizes parsed from the qkv_hidden_sizes attribute.
|
||||
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
|
||||
|
|
|
|||
|
|
@ -153,6 +153,49 @@ void PrepareMask(const int32_t* mask_index,
|
|||
}
|
||||
}
|
||||
|
||||
// Applies causal mask and past seqlens right pad to the mask_data buffer.
|
||||
template <typename T>
|
||||
void PrepareMaskGQA(T* mask_data,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int buffer_sequence_length,
|
||||
int local_window_size,
|
||||
const int32_t* seqlens_k) {
|
||||
// mask_data has been filled with 0, and its shape is BxSxT
|
||||
T* p_mask = mask_data;
|
||||
// TODO: parallelize this
|
||||
for (int b_i = 0; b_i < batch_size; b_i++) {
|
||||
if (sequence_length > 1) {
|
||||
// Apply causal/local mask for prompt case.
|
||||
for (int s_i = 0; s_i < sequence_length; s_i++) {
|
||||
for (int m_i = s_i + 1; m_i < buffer_sequence_length; m_i++) {
|
||||
p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
// Apply local mask.
|
||||
if (local_window_size > 0) {
|
||||
for (int m_i = 0; m_i < s_i - local_window_size; m_i++) {
|
||||
p_mask[s_i * buffer_sequence_length + m_i] = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (sequence_length == 1) {
|
||||
// Apply right padding to mask for token gen case.
|
||||
int total_seqlen = seqlens_k[b_i] + 1;
|
||||
for (int m_i = total_seqlen; m_i < buffer_sequence_length; m_i++) {
|
||||
p_mask[m_i] = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
// Apply local mask.
|
||||
if (local_window_size > 0) {
|
||||
for (int m_i = 0; m_i < total_seqlen - local_window_size - 1; m_i++) {
|
||||
p_mask[m_i] = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
}
|
||||
}
|
||||
ptrdiff_t mask_to_advance = SafeInt<ptrdiff_t>(sequence_length) * buffer_sequence_length;
|
||||
p_mask += mask_to_advance;
|
||||
}
|
||||
}
|
||||
|
||||
// Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH
|
||||
// Returns a pointer to the start of present state chunk.
|
||||
template <typename T>
|
||||
|
|
@ -175,5 +218,32 @@ T* ConcatStateChunk(const T* past,
|
|||
return start;
|
||||
}
|
||||
|
||||
// GQA version of ConcatStateChunk.
//
// Builds the i-th chunk of the present KV buffer and returns a pointer to its start
// (present + i * present_buff_chunk_length).
//
//   past                       start of the past buffer; only read when !is_prompt, the buffers are not
//                              shared and past_chunk_length > 0 (so it may be null otherwise)
//   chunk                      new data to append (new_chunk_length elements)
//   present                    start of the present buffer
//   present_buff_chunk_length  stride, in elements, between chunks of the present buffer
//   past_buff_chunk_length     stride, in elements, between chunks of the past buffer
//   past_chunk_length          number of valid past elements in this chunk
//   new_chunk_length           number of elements in chunk
//   is_prompt                  when true the past is ignored and chunk is written at the chunk start
//   past_present_share_buffer  when true the past data is assumed to already sit in the present buffer,
//                              so only the new data is written, right after it
//   i                          chunk index
template <typename T>
T* ConcatStateChunkGQA(const T* past,
                       const T* chunk,
                       T* present,
                       size_t present_buff_chunk_length,
                       size_t past_buff_chunk_length,
                       size_t past_chunk_length,
                       size_t new_chunk_length,
                       bool is_prompt,
                       bool past_present_share_buffer,
                       std::ptrdiff_t i) {
  T* start = present + i * present_buff_chunk_length;

  T* p = start;
  if (!is_prompt) {
    // Skip the copy when there is nothing to copy: `past` may legitimately be null in that case, and
    // both pointer arithmetic on a null pointer and memcpy with a null source are undefined behaviour,
    // even for a zero length.
    if (!past_present_share_buffer && past_chunk_length > 0) {
      const T* src_past = past + i * past_buff_chunk_length;
      memcpy(p, src_past, past_chunk_length * sizeof(T));
    }
    p += past_chunk_length;
  }

  if (new_chunk_length > 0) {
    memcpy(p, chunk, new_chunk_length * sizeof(T));
  }
  return start;
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
246
onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
Normal file
246
onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
#include "attention_utils.h"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/transpose_helper.h"
|
||||
#include "core/providers/cpu/tensor/reshape_helper.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
|
||||
using onnxruntime::concurrency::ThreadPool;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// Reshape Q/K/V from BxSxD to BxSxNxH
|
||||
inline Status Reshape_BSD_to_BSNH(Tensor* qkv,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int head_size) {
|
||||
qkv->Reshape(TensorShape({batch_size, sequence_length, num_heads, head_size}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Transpose Q/K/V from BxSxNxH to BxNxSxH
|
||||
inline Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
|
||||
OrtValue& qkv_transposed,
|
||||
concurrency::ThreadPool* tp) {
|
||||
std::vector<size_t> permutations({0, 2, 1, 3});
|
||||
gsl::span<const size_t> permutations_span{permutations};
|
||||
size_t from = 2, to = 1;
|
||||
SingleAxisTranspose(permutations, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add bias + transpose for each of Q/K/V
|
||||
template <typename T>
|
||||
Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
|
||||
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
|
||||
OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
|
||||
int bias_offset, // bias offset to enter qkv_bias
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
|
||||
int num_heads, // num heads
|
||||
int head_size, // head_size for Q/K, v_head_size for V
|
||||
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
|
||||
OpKernelContext* context) {
|
||||
// Note: the comments below will refer to Q's dimensions for simplicity
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
ProcessBroadcastSpanFuncs add_funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
|
||||
}}; // For element-wise add
|
||||
|
||||
// Allocate space for output of Q(BS, D) + bias(D)
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
std::vector<int64_t> old_dims({batch_size, sequence_length, hidden_size});
|
||||
gsl::span<const int64_t> old_dims_span{old_dims};
|
||||
TensorShape qkv_with_bias_shape(old_dims_span);
|
||||
OrtValue qkv_with_bias;
|
||||
Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias);
|
||||
|
||||
// Get Q's bias from combined bias
|
||||
std::vector<int64_t> bias_dims({hidden_size});
|
||||
gsl::span<const int64_t> bias_dims_span{bias_dims};
|
||||
TensorShape bias_shape(bias_dims_span);
|
||||
OrtValue bias;
|
||||
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
|
||||
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, hidden_size * element_size);
|
||||
|
||||
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
|
||||
{
|
||||
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
|
||||
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
|
||||
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
|
||||
|
||||
size_t span_size = input_broadcaster.GetSpanSize();
|
||||
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
|
||||
void* user_data = nullptr;
|
||||
|
||||
const int loop_len = static_cast<int>(output_size / span_size);
|
||||
double unit_cost = 1.0f;
|
||||
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
|
||||
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
|
||||
unit_cost * span_size};
|
||||
auto tp = context->GetOperatorThreadPool();
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost,
|
||||
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
|
||||
std::ptrdiff_t last_span) {
|
||||
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
|
||||
segment_input_broadcaster.AdvanceBy(first_span * span_size);
|
||||
|
||||
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
|
||||
first_span * span_size, last_span * span_size);
|
||||
|
||||
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
|
||||
BroadcastLooper(segment_helper, add_funcs);
|
||||
});
|
||||
}
|
||||
|
||||
// Reshape Q from BxSxD to BxSxNxH
|
||||
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
|
||||
|
||||
// Transpose Q from BxSxNxH to BxNxSxH
|
||||
auto tp = context->GetOperatorThreadPool();
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add bias + reshape for each of Q/K/V
|
||||
// This is used in decoder_with_past when the sequence length is 1
|
||||
template <typename T>
|
||||
Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
|
||||
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
|
||||
OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
|
||||
int bias_offset, // bias offset to enter qkv_bias
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
|
||||
int num_heads, // num heads
|
||||
int head_size, // head_size for Q/K, v_head_size for V
|
||||
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
|
||||
OpKernelContext* context) {
|
||||
// Note: the comments below will refer to Q's dimensions for simplicity
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
ProcessBroadcastSpanFuncs add_funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
|
||||
}}; // For element-wise add
|
||||
|
||||
// Get Q's bias from combined bias
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
std::vector<int64_t> bias_dims({hidden_size});
|
||||
gsl::span<const int64_t> bias_dims_span{bias_dims};
|
||||
TensorShape bias_shape(bias_dims_span);
|
||||
OrtValue bias;
|
||||
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
|
||||
auto num_bias_elements = SafeInt<size_t>(hidden_size) * element_size;
|
||||
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, num_bias_elements);
|
||||
|
||||
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
|
||||
{
|
||||
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
|
||||
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
|
||||
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
|
||||
|
||||
size_t span_size = input_broadcaster.GetSpanSize();
|
||||
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
|
||||
void* user_data = nullptr;
|
||||
|
||||
const int loop_len = static_cast<int>(output_size / span_size);
|
||||
double unit_cost = 1.0f;
|
||||
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
|
||||
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
|
||||
unit_cost * span_size};
|
||||
auto tp = context->GetOperatorThreadPool();
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost,
|
||||
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
|
||||
std::ptrdiff_t last_span) {
|
||||
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
|
||||
segment_input_broadcaster.AdvanceBy(first_span * span_size);
|
||||
|
||||
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
|
||||
first_span * span_size, last_span * span_size);
|
||||
|
||||
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
|
||||
BroadcastLooper(segment_helper, add_funcs);
|
||||
});
|
||||
}
|
||||
|
||||
// Reshape Q from BxSxD to BxNxSxH
|
||||
qkv_with_bias.GetMutable<Tensor>()->Reshape(TensorShape({batch_size, num_heads, sequence_length, head_size}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
|
||||
int batch_size, int num_heads, int sequence_length, int head_size,
|
||||
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) {
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
std::vector<int64_t> new_dims({batch_size, num_heads, sequence_length, head_size});
|
||||
gsl::span<const int64_t> new_dims_span{new_dims};
|
||||
TensorShape v_BNLH(new_dims_span);
|
||||
Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);
|
||||
if (bias == nullptr) {
|
||||
std::unique_ptr<Tensor> reshaped;
|
||||
if (in->Shape().GetDims().size() == 3) {
|
||||
reshaped = std::make_unique<Tensor>(in->DataType(), in->Shape(), const_cast<void*>(in->DataRaw()), in->Location());
|
||||
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));
|
||||
} else {
|
||||
const auto* qkv_bias = bias->Data<T>();
|
||||
if (sequence_length == 1) {
|
||||
ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
template Status MaybeTransposeToBNSHAndAddBias<float>(OpKernelContext* context, AllocatorPtr allocator,
|
||||
int batch_size, int num_heads, int sequence_length, int head_size,
|
||||
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
|
||||
|
||||
template <typename T>
|
||||
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
|
||||
int batch_size, int num_heads, int sequence_length, int head_size,
|
||||
const Tensor* in, OrtValue& out) {
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
std::vector<int64_t> new_dims({batch_size, num_heads, sequence_length, head_size});
|
||||
gsl::span<const int64_t> new_dims_span{new_dims};
|
||||
TensorShape v_BNLH(new_dims_span);
|
||||
Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);
|
||||
std::unique_ptr<Tensor> reshaped;
|
||||
if (in->Shape().GetDims().size() == 3) {
|
||||
reshaped = std::make_unique<Tensor>(in->DataType(), in->Shape(), const_cast<void*>(in->DataRaw()), in->Location());
|
||||
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
template Status MaybeTransposeToBNSH<float>(AllocatorPtr allocator,
|
||||
int batch_size, int num_heads, int sequence_length, int head_size,
|
||||
const Tensor* in, OrtValue& out);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
61
onnxruntime/contrib_ops/cpu/bert/attention_utils.h
Normal file
61
onnxruntime/contrib_ops/cpu/bert/attention_utils.h
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/transpose_helper.h"
|
||||
#include "core/providers/cpu/tensor/reshape_helper.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// Reshape Q/K/V from BxSxD to BxSxNxH
|
||||
Status Reshape_BSD_to_BSNH(Tensor* qkv,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int head_size);
|
||||
|
||||
// Transpose Q/K/V from BxSxNxH to BxNxSxH
|
||||
Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
|
||||
OrtValue& qkv_transposed,
|
||||
concurrency::ThreadPool* tp = nullptr);
|
||||
|
||||
// Add bias + transpose for each of Q/K/V
|
||||
template <typename T>
|
||||
Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
|
||||
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
|
||||
OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
|
||||
int bias_offset, // bias offset to enter qkv_bias
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
|
||||
int num_heads, // num heads
|
||||
int head_size, // head_size for Q/K, v_head_size for V
|
||||
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
|
||||
OpKernelContext* context); // OpKernelContext
|
||||
|
||||
// Add bias + reshape for each of Q/K/V
|
||||
// This is used in decoder_with_past when the sequence length is 1
|
||||
template <typename T>
|
||||
Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
|
||||
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
|
||||
OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
|
||||
int bias_offset, // bias offset to enter qkv_bias
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
|
||||
int num_heads, // num heads
|
||||
int head_size, // head_size for Q/K, v_head_size for V
|
||||
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
|
||||
OpKernelContext* context); // OpKernelContext
|
||||
|
||||
template <typename T>
|
||||
Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
|
||||
int batch_size, int num_heads, int sequence_length, int head_size,
|
||||
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
|
||||
|
||||
template <typename T>
|
||||
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
|
||||
int batch_size, int num_heads, int sequence_length, int head_size,
|
||||
const Tensor* in, OrtValue& out);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
276
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Normal file
276
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "attention_base.h"
|
||||
#include "attention_helper.h"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class GQAAttentionBase : public AttentionBase {
|
||||
protected:
|
||||
GQAAttentionBase(const OpKernelInfo& info, bool require_same_hidden_size)
|
||||
: AttentionBase(info, require_same_hidden_size) {}
|
||||
|
||||
int local_window_size_;
|
||||
bool do_rotary_;
|
||||
bool rotary_interleaved_;
|
||||
|
||||
template <typename T>
|
||||
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
|
||||
const T* K, // K data with shape BxN_kvxSxH
|
||||
const T* V, // V data with shape BxN_kvxSxH
|
||||
const Tensor* past_key, // past K input tensor (if not using past state)
|
||||
const Tensor* past_value, // past V input tensor (if not using past state)
|
||||
Tensor* output, // output tensor
|
||||
Tensor* present_key, // present K output tensor (if separating present KV)
|
||||
Tensor* present_value, // present V output tensor (if separating present KV)
|
||||
const Tensor* seqlens_k, // past sequence lengths tensor
|
||||
GroupQueryAttentionParameters& parameters, // attention parameters
|
||||
AllocatorPtr allocator, // allocator for temporary tensors
|
||||
OpKernelContext* context) const {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int head_size = parameters.head_size;
|
||||
const int hidden_size = parameters.hidden_size;
|
||||
const bool packed_qkv = parameters.is_packed_qkv;
|
||||
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
|
||||
int seqlen_past_kv_cache = 0;
|
||||
if (past_key != nullptr && past_value != nullptr) {
|
||||
seqlen_past_kv_cache = static_cast<int>(past_key->Shape().GetDims()[2]);
|
||||
}
|
||||
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
|
||||
|
||||
// Compute the attention score.
|
||||
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
|
||||
auto attention_probs = allocator->Alloc(bytes);
|
||||
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
|
||||
|
||||
void* mask_data = nullptr;
|
||||
size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * seqlen_present_kv_cache * sizeof(T);
|
||||
mask_data = allocator->Alloc(mask_data_bytes);
|
||||
memset(mask_data, 0, mask_data_bytes);
|
||||
BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
|
||||
|
||||
const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
|
||||
T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
|
||||
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
|
||||
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;
|
||||
|
||||
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
|
||||
|
||||
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
|
||||
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k,
|
||||
seqlens_k->Data<int32_t>(), static_cast<T*>(mask_data),
|
||||
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache,
|
||||
head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp);
|
||||
|
||||
// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
|
||||
auto out_tmp_data =
|
||||
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
|
||||
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
|
||||
|
||||
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
|
||||
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
|
||||
v, seqlens_k->Data<int32_t>(), batch_size, sequence_length, seqlen_past_kv_cache,
|
||||
seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data,
|
||||
past_present_share_buffer, packed_qkv, tp);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
// Helper function to compute the attention probs. It does 2 things:
|
||||
// attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) +
|
||||
// 1 x mask_data(B, N, S, T)
|
||||
// attention_probs(B, N, S, T) = Softmax(attention_probs)
|
||||
template <typename T>
|
||||
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
|
||||
const T* Q, // Q data. Its size is BxNxSxH
|
||||
const T* K, // k data. Its size is BxNxLxH
|
||||
const int32_t* seqlens_k, // past sequence lengths tensor
|
||||
T* mask_data, // buffer for mask data.
|
||||
int batch_size, // batch size of self-attention
|
||||
int sequence_length, // sequence length of self-attention (S)
|
||||
int past_buffer_sequence_length, // sequence length of past state
|
||||
int present_buffer_sequence_length, // sequence length of present state
|
||||
int head_size, // head size of self-attention
|
||||
const T* past_key, // past key only (if not using past state)
|
||||
T* present_key, // present key only (if not using present state)
|
||||
bool past_present_share_buffer, // whether present key and value share the same buffer
|
||||
bool packed_qkv, // whether Q, K, V are packed
|
||||
ThreadPool* tp) const { // thread pool
|
||||
const bool is_prompt = sequence_length != 1;
|
||||
const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
|
||||
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
|
||||
const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
|
||||
const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // L x H
|
||||
const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size; // L x H
|
||||
const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size; // T x H
|
||||
|
||||
PrepareMaskGQA(mask_data, batch_size, sequence_length, present_buffer_sequence_length, local_window_size_, seqlens_k);
|
||||
|
||||
const int loop_len = batch_size * num_heads_;
|
||||
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
|
||||
|
||||
// TODO: cost might differ for gqa because of right padding and total_sequence_length being sequence dependent
|
||||
TensorOpCost unit_cost;
|
||||
const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * present_buffer_sequence_length * sizeof(T);
|
||||
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
|
||||
unit_cost.bytes_loaded = static_cast<double>((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T));
|
||||
unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);
|
||||
|
||||
unit_cost.bytes_loaded += static_cast<double>(probs_matrix_bytes);
|
||||
unit_cost.bytes_stored += static_cast<double>(probs_matrix_bytes);
|
||||
|
||||
if (present_key) {
|
||||
double bytes_to_copy_key = static_cast<double>(sizeof(T) * present_buff_chunk_length);
|
||||
unit_cost.bytes_loaded += bytes_to_copy_key;
|
||||
unit_cost.bytes_stored += bytes_to_copy_key;
|
||||
}
|
||||
|
||||
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t i = begin; i != end; ++i) {
|
||||
const int batch_index = static_cast<int>(i) / num_heads_;
|
||||
const int head_index = static_cast<int>(i) % num_heads_;
|
||||
const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
|
||||
const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
|
||||
|
||||
const int output_offset = static_cast<int>(i) * sequence_length * present_buffer_sequence_length;
|
||||
const int mask_offset = batch_index * sequence_length * present_buffer_sequence_length;
|
||||
T* output = attention_probs + output_offset;
|
||||
|
||||
// Broadcast mask data: (Bx)SxT -> (BxNx)SxT
|
||||
// TODO: mask after present_sequence_length
|
||||
memcpy(output,
|
||||
mask_data + mask_offset,
|
||||
probs_matrix_bytes);
|
||||
|
||||
const T* k;
|
||||
if (packed_qkv) {
|
||||
k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
|
||||
} else {
|
||||
k = K + kv_input_chunk_length * (i / kv_num_heads_factor);
|
||||
}
|
||||
if (nullptr != present_key) {
|
||||
k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
|
||||
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
|
||||
i / kv_num_heads_factor);
|
||||
}
|
||||
|
||||
// TODO: CblasTrans stuff what do?
|
||||
// Compute Q*K' + AttentionMask
|
||||
// original transposed each iteration
|
||||
// A: Q (B x N x) S x H (B x N x) S x H S x H
|
||||
// B: K' (B x N x) T x H (B x N x) H x T H x T
|
||||
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
|
||||
const T* q;
|
||||
if (packed_qkv) {
|
||||
q = Q + packed_batch_stride * batch_index + q_input_chunk_length * head_index;
|
||||
} else {
|
||||
q = Q + q_input_chunk_length * i;
|
||||
}
|
||||
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, present_buffer_sequence_length, head_size, alpha,
|
||||
q, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr);
|
||||
}
|
||||
});
|
||||
|
||||
// attention_probs(B, N, S, T) = Softmax(attention_probs)
|
||||
const int N = batch_size * num_heads_ * sequence_length;
|
||||
const int D = present_buffer_sequence_length;
|
||||
ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
|
||||
T* tmp_buffer, // buffer for temp use with size is BxNxSxH
|
||||
const T* attention_probs, // Attention probs with size BxNxSxT
|
||||
const T* V, // V value with size BxN_kvxSxH
|
||||
const int32_t* seqlens_k, // past sequence lengths tensor
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence length
|
||||
int past_buffer_sequence_length, // sequence length in past state
|
||||
int present_buffer_sequence_length, // sequence length in past state
|
||||
int head_size, // head size of Q, K, V
|
||||
int hidden_size, // hidden size of Output
|
||||
const T* past_value, // past value only (if not using past state)
|
||||
T* present_value, // present value only (if not using present state)
|
||||
bool past_present_share_buffer, // whether present key and value share the same buffer
|
||||
bool packed_qkv, // whether Q, K, V are packed
|
||||
ThreadPool* tp) const {
|
||||
const bool is_prompt = sequence_length != 1;
|
||||
const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
|
||||
const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
|
||||
// TODO: what is with these being ptrdiff_t?
|
||||
const ptrdiff_t q_input_chunk_length = SafeInt<ptrdiff_t>(sequence_length) * head_size; // S x H
|
||||
const ptrdiff_t kv_input_chunk_length = SafeInt<ptrdiff_t>(sequence_length) * head_size; // L x H
|
||||
const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size; // L x H
|
||||
const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size; // T x H
|
||||
|
||||
// The cost of Gemm
|
||||
TensorOpCost unit_cost;
|
||||
unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * present_buffer_sequence_length);
|
||||
unit_cost.bytes_loaded = static_cast<double>((sequence_length + head_size) * present_buffer_sequence_length * sizeof(T));
|
||||
unit_cost.bytes_stored = static_cast<double>(sequence_length * head_size * sizeof(T));
|
||||
|
||||
if (present_value) {
|
||||
double bytes_to_copy_value = static_cast<double>(present_buff_chunk_length * sizeof(T));
|
||||
unit_cost.bytes_loaded += bytes_to_copy_value;
|
||||
unit_cost.bytes_stored += bytes_to_copy_value;
|
||||
}
|
||||
|
||||
const size_t bytes_to_copy_trans = SafeInt<size_t>(head_size) * sizeof(T);
|
||||
double bytes_to_copy_trans_all = static_cast<double>(sequence_length * bytes_to_copy_trans);
|
||||
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
|
||||
unit_cost.bytes_stored += bytes_to_copy_trans_all;
|
||||
|
||||
ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t i = begin; i != end; ++i) {
|
||||
const int batch_index = static_cast<int>(i / num_heads_);
|
||||
const int head_index = static_cast<int>(i % num_heads_);
|
||||
const int past_seqlen = sequence_length == 1 ? static_cast<int>(seqlens_k[batch_index]) : past_buffer_sequence_length;
|
||||
const size_t past_chunk_length = static_cast<size_t>(past_seqlen) * head_size;
|
||||
|
||||
const T* v;
|
||||
if (packed_qkv) {
|
||||
v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
|
||||
} else {
|
||||
v = V + kv_input_chunk_length * (i / kv_num_heads_factor);
|
||||
}
|
||||
if (nullptr != present_value) {
|
||||
v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
|
||||
past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
|
||||
i / kv_num_heads_factor);
|
||||
}
|
||||
|
||||
T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
|
||||
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
|
||||
math::MatMul<T>(sequence_length, head_size, present_buffer_sequence_length,
|
||||
attention_probs + attention_probs_offset,
|
||||
v, current_tmp_data, nullptr);
|
||||
|
||||
// Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
|
||||
T* src = current_tmp_data;
|
||||
ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * head_size;
|
||||
T* dest = output + dest_offset;
|
||||
for (int j = 0; j < sequence_length; j++) {
|
||||
memcpy(dest, src, bytes_to_copy_trans);
|
||||
src += head_size;
|
||||
dest += hidden_size;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
192
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Normal file
192
onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "group_query_attention.h"
|
||||
#include "group_query_attention_helper.h"
|
||||
#include "attention_utils.h"
|
||||
#include "rotary_embedding.h"
|
||||
#include "rotary_embedding_helper.h"
|
||||
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
#include <vector>
|
||||
|
||||
using onnxruntime::concurrency::ThreadPool;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// These ops are internal-only, so register outside of onnx
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(
|
||||
GroupQueryAttention,
|
||||
kMSDomain,
|
||||
1,
|
||||
float,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
GroupQueryAttention<float>);
|
||||
|
||||
template <typename T>
|
||||
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info) : OpKernel(info), GQAAttentionBase(info, false) {
|
||||
int64_t num_heads = 0;
|
||||
int64_t kv_num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
|
||||
num_heads_ = static_cast<int>(num_heads);
|
||||
kv_num_heads_ = static_cast<int>(kv_num_heads);
|
||||
|
||||
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
|
||||
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
|
||||
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
|
||||
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* query = context->Input<Tensor>(0);
|
||||
const Tensor* key = context->Input<Tensor>(1);
|
||||
const Tensor* value = context->Input<Tensor>(2);
|
||||
const Tensor* past_key = context->Input<Tensor>(3);
|
||||
const Tensor* past_value = context->Input<Tensor>(4);
|
||||
const Tensor* seqlens_k = context->Input<Tensor>(5);
|
||||
const Tensor* total_seqlen = context->Input<Tensor>(6);
|
||||
const Tensor* cos_cache = context->Input<Tensor>(7);
|
||||
const Tensor* sin_cache = context->Input<Tensor>(8);
|
||||
|
||||
GroupQueryAttentionParameters parameters = {};
|
||||
constexpr float scale = 1.0f;
|
||||
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
|
||||
key,
|
||||
value,
|
||||
past_key,
|
||||
past_value,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
¶meters,
|
||||
num_heads_,
|
||||
kv_num_heads_,
|
||||
seqlens_k,
|
||||
total_seqlen,
|
||||
scale));
|
||||
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int present_kv_seqlen = parameters.seqlen_present_kv_cache;
|
||||
int head_size = parameters.head_size;
|
||||
int q_hidden_size = parameters.hidden_size;
|
||||
const bool packed_qkv = parameters.is_packed_qkv;
|
||||
|
||||
std::vector<int64_t> output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(batch_size);
|
||||
output_shape[1] = static_cast<int64_t>(sequence_length);
|
||||
output_shape[2] = static_cast<int64_t>(q_hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(head_size)});
|
||||
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(head_size)});
|
||||
Tensor* present_k = context->Output(1, present_k_shape);
|
||||
Tensor* present_v = context->Output(2, present_v_shape);
|
||||
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
OrtValue Q;
|
||||
OrtValue K;
|
||||
OrtValue V;
|
||||
if (packed_qkv) {
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
|
||||
} else if (sequence_length > 1) {
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
|
||||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
|
||||
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
|
||||
} else {
|
||||
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*query)), Q);
|
||||
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*key)), K);
|
||||
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*value)), V);
|
||||
}
|
||||
|
||||
if (do_rotary_) {
|
||||
rotary_embedding_helper::RotaryParameters rotary_params = {};
|
||||
rotary_params.batch_size = batch_size;
|
||||
rotary_params.sequence_length = sequence_length;
|
||||
rotary_params.hidden_size = q_hidden_size;
|
||||
rotary_params.head_size = head_size;
|
||||
rotary_params.rotary_embedding_dim = parameters.rotary_dim;
|
||||
rotary_params.num_heads = num_heads_;
|
||||
rotary_params.max_sequence_length = sequence_length; // unused
|
||||
rotary_params.seq_stride = head_size;
|
||||
rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
|
||||
rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride;
|
||||
rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
|
||||
rotary_params.transposed = true;
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
|
||||
if (sequence_length == 1) {
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
|
||||
}
|
||||
} else {
|
||||
pos_ids[0] = static_cast<int64_t>(0);
|
||||
}
|
||||
const T* q_input;
|
||||
const T* k_input;
|
||||
T* q_rotary;
|
||||
T* k_rotary;
|
||||
if (packed_qkv) {
|
||||
OrtValue RotaryQKV;
|
||||
Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
|
||||
q_input = Q.Get<Tensor>().Data<T>();
|
||||
k_input = q_input + num_heads_ * sequence_length * head_size;
|
||||
q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
|
||||
k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
|
||||
Q = RotaryQKV;
|
||||
} else {
|
||||
OrtValue RotaryQ;
|
||||
Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ);
|
||||
OrtValue RotaryK;
|
||||
Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK);
|
||||
q_input = Q.Get<Tensor>().Data<T>();
|
||||
k_input = K.Get<Tensor>().Data<T>();
|
||||
q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>();
|
||||
k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>();
|
||||
Q = RotaryQ;
|
||||
K = RotaryK;
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
|
||||
pos_ids.data(), cos_cache->Data<T>(),
|
||||
sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
|
||||
|
||||
rotary_params.num_heads = kv_num_heads_;
|
||||
rotary_params.hidden_size = parameters.kv_hidden_size;
|
||||
if (!packed_qkv) {
|
||||
rotary_params.batch_stride = kv_num_heads_ * rotary_params.head_stride;
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
|
||||
pos_ids.data(), cos_cache->Data<T>(),
|
||||
sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
|
||||
if (packed_qkv) {
|
||||
const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
|
||||
T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
|
||||
ORT_RETURN_IF_ERROR(group_query_attention_helper::PackVIntoRotaryQKV<T>(tp, parameters, v_input, v_rotary));
|
||||
}
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
// Compute the attention score and apply the score to V
|
||||
return ApplyAttention(Q.Get<Tensor>().Data<T>(), packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(),
|
||||
packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), past_key, past_value, output, present_k, present_v,
|
||||
seqlens_k, parameters, allocator, context);
|
||||
}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
21
onnxruntime/contrib_ops/cpu/bert/group_query_attention.h
Normal file
21
onnxruntime/contrib_ops/cpu/bert/group_query_attention.h
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "gqa_attention_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
class GroupQueryAttention final : public OpKernel, public GQAAttentionBase {
|
||||
public:
|
||||
GroupQueryAttention(const OpKernelInfo& info);
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
299
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Normal file
299
onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace group_query_attention_helper {
|
||||
|
||||
// Validates the GroupQueryAttention inputs and, when `parameters` is non-null, fills the
// GroupQueryAttentionParameters it points to.
//
// `key == nullptr` selects the packed-QKV layout, in which `query` carries Q, K and V and
// `value` must be nullptr as well. `past_key`/`past_value` and `cos_cache`/`sin_cache` are
// optional, but each pair must be either fully present or fully absent.
//
// Returns INVALID_ARGUMENT with a descriptive message on the first violated constraint.
//
// Declared `inline` because this is a function definition in a header; without it every
// translation unit including the header emits its own external definition.
inline Status CheckInputs(const Tensor* query,
                          const Tensor* key,
                          const Tensor* value,
                          const Tensor* past_key,
                          const Tensor* past_value,
                          const Tensor* cos_cache,
                          const Tensor* sin_cache,
                          void* parameters,
                          int num_heads,
                          int kv_num_heads,
                          const Tensor* seqlens_k,
                          const Tensor* total_seqlen,
                          float scale) {
  // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
  //     past_key                   : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
  //     past_value                 : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
  // no packing for q/k/v:
  //     query            (Q)       : (B, S, D) or (B, S, (D_q + 2 D_kv))
  //     key              (K)       : (B, S, D_kv) or nullptr
  //     value            (V)       : (B, S, D_kv) or nullptr

  AttentionQkvFormat qkv_format = Q_K_V_BSNH;
  AttentionQkvFormat past_kv_format = Q_K_V_BNSH;
  const bool is_packed_qkv = key == nullptr;

  const auto& query_dims = query->Shape().GetDims();
  if (query_dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
                           query_dims.size());
  }

  int batch_size = static_cast<int>(query_dims[0]);
  int sequence_length = static_cast<int>(query_dims[1]);
  int q_hidden_size = static_cast<int>(query_dims[2]);
  int head_size = 0;

  if (num_heads % kv_num_heads != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
                           num_heads % kv_num_heads);
  }

  int kv_hidden_size = 0;
  // Check key and value when not packed
  if (!is_packed_qkv) {
    head_size = q_hidden_size / num_heads;
    if (head_size % 8 != 0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "head_size must be a multiple of 8. Got head_size % 8 == ",
                             head_size % 8);
    }
    if (value == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
    }
    const auto& key_dims = key->Shape().GetDims();
    if (key_dims.size() != 3) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
                             key_dims.size());
    } else if (query_dims[0] != key_dims[0]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'query' and 'key' shall have same dim 0 (batch size)");
    } else if (query_dims[1] != key_dims[1]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'query' and 'key' shall have same dim 1 (sequence length)");
    }
    kv_hidden_size = static_cast<int>(key_dims[2]);
    const auto& value_dims = value->Shape().GetDims();
    if (value_dims.size() != 3) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
                             value_dims.size());
    } else if (query_dims[0] != value_dims[0]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'query' and 'value' shall have same dim 0 (batch size)");
    } else if (query_dims[1] != value_dims[1]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'query' and 'value' shall have same dim 1 (sequence length)");
    } else if (value_dims[2] != kv_hidden_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
    }
  } else {
    // Check packed qkv: the last query dimension holds num_heads Q heads plus kv_num_heads K and V heads.
    head_size = q_hidden_size / (num_heads + 2 * kv_num_heads);
    if (head_size % 8 != 0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "head_size must be a multiple of 8. Got head_size % 8 == ",
                             head_size % 8);
    }
    if (value != nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
    }
    q_hidden_size = head_size * num_heads;
    kv_hidden_size = head_size * kv_num_heads;
  }

  // Check past-present KV
  int32_t past_sequence_length = 0;
  if (past_key != nullptr && past_value != nullptr) {
    const auto& past_key_dims = past_key->Shape().GetDims();
    const auto& past_value_dims = past_value->Shape().GetDims();

    if (past_key_dims.size() != 4) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_key' is expected to have 4 dimensions, got ",
                             past_key_dims.size());
    }
    if (past_value_dims.size() != 4) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_value' is expected to have 4 dimensions, got ",
                             past_value_dims.size());
    }

    if (past_key_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_key' dimension 0 should be batch_size, got ",
                             past_key_dims[0]);
    }
    if (past_value_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_value' dimension 0 should be batch_size, got ",
                             past_value_dims[0]);
    }

    if (past_key_dims[2] != past_value_dims[2]) {
      // Report the dimension that was actually compared (index 2).
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence "
                             "length or past sequence length), got ",
                             past_key_dims[2]);
    }
    if (past_key_dims[1] != kv_num_heads) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_key' shall have kv_num_heads");
    }
    if (past_value_dims[1] != kv_num_heads) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_value' shall have kv_num_heads");
    }
    // We assume all sequence in past kv are right-padded to max or past sequence length
    past_sequence_length = static_cast<int>(past_key_dims[2]);

    if (past_key_dims[3] != head_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_key' dimension 3 should be same as head_size, got ",
                             past_key_dims[3]);
    }
    if (past_value_dims[3] != head_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'past_value' dimension 3 should be same as head_size, got ",
                             past_value_dims[3]);
    }
  } else if (past_key != nullptr || past_value != nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'past_key' and 'past_value' shall be both present or both absent.");
  }

  // Check seqlens_k tensor (holding past seqlen for token gen).
  // The kernel indexes this tensor once per batch entry, so it must hold exactly batch_size
  // elements. The previous test (`rank != 1 && dim0 != batch_size`) let a 1-D tensor of the
  // wrong length through and read dim 0 of a rank-0 shape.
  // NOTE(review): checking the element count also admits shapes such as (batch_size, 1);
  // tighten to rank 1 only if no existing model relies on that.
  if (seqlens_k->Shape().Size() != static_cast<int64_t>(batch_size)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "seqlens_k must be shape (batch_size).");
  }

  // Set present sequence length and kv_share_buffer from input total_seqlen tensor
  if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "total_sequence_length tensor must be of one element.");
  }
  int total_sequence_length = *total_seqlen->Data<int32_t>();
  int present_sequence_length = std::max(total_sequence_length, past_sequence_length);

  int rotary_dim = 0;
  if (cos_cache != nullptr && sin_cache != nullptr) {
    const auto& cos_dims = cos_cache->Shape().GetDims();
    const auto& sin_dims = sin_cache->Shape().GetDims();

    // Dimensions 0 and 1 are indexed below; reject other ranks up front.
    if (cos_dims.size() != 2 || sin_dims.size() != 2) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "cos_cache and sin_cache are expected to have 2 dimensions.");
    }
    if (head_size % 16 != 0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "head_size shall be a multiple of 16. Got head_size % 16 == ",
                             head_size % 16);
    }
    if (cos_dims[0] < present_sequence_length) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "cos_cache dimension 0 should be of max_sequence_length.");
    }
    if (sin_dims[0] < present_sequence_length) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "sin_cache dimension 0 should be of max_sequence_length.");
    }
    if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
    }
    if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
    }
    if (cos_dims[1] != sin_dims[1]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "cos_cache and sin_cache dimension 1 must be the same.");
    }
    // The caches store half of the rotary dimension.
    rotary_dim = static_cast<int>(cos_dims[1] * 2);
  } else if (cos_cache != nullptr || sin_cache != nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
  }

  bool is_prompt = sequence_length != 1;

  if (parameters != nullptr) {
    GroupQueryAttentionParameters* output_parameters = static_cast<GroupQueryAttentionParameters*>(parameters);
    output_parameters->batch_size = batch_size;
    output_parameters->sequence_length = sequence_length;                  // sequence length of Q
    output_parameters->seqlen_past_kv_cache = past_sequence_length;        // max sequence length of past kv tensors
    output_parameters->seqlen_present_kv_cache = present_sequence_length;  // max sequence length of present kv tensors
    output_parameters->hidden_size = q_hidden_size;
    output_parameters->num_heads = num_heads;
    output_parameters->head_size = head_size;
    output_parameters->kv_hidden_size = kv_hidden_size;
    output_parameters->kv_num_heads = kv_num_heads;
    output_parameters->rotary_dim = rotary_dim;
    output_parameters->is_packed_qkv = is_packed_qkv;
    output_parameters->is_unidirectional = true;
    output_parameters->is_prompt = is_prompt;
    output_parameters->scale = scale;
    output_parameters->qkv_format = qkv_format;
    output_parameters->past_kv_format = past_kv_format;
  }

  return Status::OK();
}
|
||||
|
||||
// Overload that first enforces an upper bound on num_heads and then delegates to the
// 13-argument CheckInputs. A non-positive `max_threads_per_block` disables the bound.
//
// Declared `inline` because this is a function definition in a header; without it every
// translation unit including the header emits its own external definition.
inline Status CheckInputs(const Tensor* query,
                          const Tensor* key,
                          const Tensor* value,
                          const Tensor* past_key,
                          const Tensor* past_value,
                          const Tensor* cos_cache,
                          const Tensor* sin_cache,
                          void* parameters,
                          int num_heads,
                          int kv_num_heads,
                          const Tensor* seqlens_k,
                          const Tensor* total_seqlen,
                          float scale,
                          int max_threads_per_block) {
  if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
  }

  return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads,
                     kv_num_heads, seqlens_k, total_seqlen, scale);
}
|
||||
|
||||
template <typename T>
|
||||
Status PackVIntoRotaryQKV(concurrency::ThreadPool* tp, GroupQueryAttentionParameters parameters, const T* input,
|
||||
T* output) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
int n_heads = parameters.num_heads;
|
||||
const int kv_n_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
int seq_stride = head_size;
|
||||
int head_stride = sequence_length * seq_stride;
|
||||
int batch_stride = (n_heads + 2 * kv_n_heads) * head_stride;
|
||||
|
||||
const int loop_len = batch_size * sequence_length * kv_n_heads;
|
||||
const double cost = static_cast<double>(head_size);
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
|
||||
const int b = static_cast<int>((ptr / kv_n_heads) / sequence_length);
|
||||
const int s = static_cast<int>((ptr / kv_n_heads) % sequence_length);
|
||||
const int n = static_cast<int>(ptr % kv_n_heads);
|
||||
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
|
||||
const T* input_data = input + block_offset;
|
||||
T* output_data = output + block_offset;
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
output_data[i] = input_data[i];
|
||||
}
|
||||
}
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace group_query_attention_helper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -4,15 +4,13 @@
|
|||
#include "attention_cpu_base.h"
|
||||
#include "multihead_attention.h"
|
||||
#include "multihead_attention_helper.h"
|
||||
#include "attention_utils.h"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/transpose_helper.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#include "core/providers/cpu/tensor/reshape_helper.h"
|
||||
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
#include <vector>
|
||||
|
|
@ -43,217 +41,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
|
|||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
|
||||
}
|
||||
|
||||
// Reshape Q/K/V from BxSxD to BxSxNxH
|
||||
Status Reshape_BSD_to_BSNH(Tensor* qkv,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int head_size) {
|
||||
std::vector<int64_t> reshape_dims({batch_size, sequence_length, num_heads, head_size});
|
||||
gsl::span<const int64_t> reshape_dims_span{reshape_dims};
|
||||
TensorShape qkv_bsnh(reshape_dims_span);
|
||||
qkv->Reshape(qkv_bsnh);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Transpose Q/K/V from BxSxNxH to BxNxSxH
|
||||
Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
|
||||
OrtValue& qkv_transposed,
|
||||
concurrency::ThreadPool* tp = nullptr) {
|
||||
std::vector<size_t> permutations({0, 2, 1, 3});
|
||||
gsl::span<const size_t> permutations_span{permutations};
|
||||
size_t from = 2, to = 1;
|
||||
SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add bias + transpose for each of Q/K/V
|
||||
template <typename T>
|
||||
Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
|
||||
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
|
||||
OrtValue& qkv_with_bias_transposed, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
|
||||
int bias_offset, // bias offset to enter qkv_bias
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
|
||||
int num_heads, // num heads
|
||||
int head_size, // head_size for Q/K, v_head_size for V
|
||||
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
|
||||
OpKernelContext* context) {
|
||||
// Note: the comments below will refer to Q's dimensions for simplicity
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
ProcessBroadcastSpanFuncs add_funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
|
||||
}}; // For element-wise add
|
||||
|
||||
// Allocate space for output of Q(BS, D) + bias(D)
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
std::vector<int64_t> old_dims({batch_size, sequence_length, hidden_size});
|
||||
gsl::span<const int64_t> old_dims_span{old_dims};
|
||||
TensorShape qkv_with_bias_shape(old_dims_span);
|
||||
OrtValue qkv_with_bias;
|
||||
Tensor::InitOrtValue(element_type, qkv_with_bias_shape, allocator, qkv_with_bias);
|
||||
|
||||
// Get Q's bias from combined bias
|
||||
std::vector<int64_t> bias_dims({hidden_size});
|
||||
gsl::span<const int64_t> bias_dims_span{bias_dims};
|
||||
TensorShape bias_shape(bias_dims_span);
|
||||
OrtValue bias;
|
||||
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
|
||||
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, hidden_size * element_size);
|
||||
|
||||
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
|
||||
{
|
||||
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
|
||||
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
|
||||
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
|
||||
|
||||
size_t span_size = input_broadcaster.GetSpanSize();
|
||||
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
|
||||
void* user_data = nullptr;
|
||||
|
||||
const int loop_len = static_cast<int>(output_size / span_size);
|
||||
double unit_cost = 1.0f;
|
||||
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
|
||||
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
|
||||
unit_cost * span_size};
|
||||
auto tp = context->GetOperatorThreadPool();
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost,
|
||||
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
|
||||
std::ptrdiff_t last_span) {
|
||||
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
|
||||
segment_input_broadcaster.AdvanceBy(first_span * span_size);
|
||||
|
||||
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
|
||||
first_span * span_size, last_span * span_size);
|
||||
|
||||
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
|
||||
BroadcastLooper(segment_helper, add_funcs);
|
||||
});
|
||||
}
|
||||
|
||||
// Reshape Q from BxSxD to BxSxNxH
|
||||
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
|
||||
|
||||
// Transpose Q from BxSxNxH to BxNxSxH
|
||||
auto tp = context->GetOperatorThreadPool();
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add bias + reshape for each of Q/K/V
|
||||
// This is used in decoder_with_past when the sequence length is 1
|
||||
template <typename T>
|
||||
Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is BxSxD, key is BxLxD, value is BxLxD_v
|
||||
const T* qkv_bias, // Input: QKV bias - bias is (D + D + D_v)
|
||||
OrtValue& qkv_with_bias, // Output: Q/K/V data - query is BxNxSxH, key is BxNxLxH, value is BxNxLxH_v
|
||||
int bias_offset, // bias offset to enter qkv_bias
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence_length for Q, kv_sequence_length for K/V
|
||||
int num_heads, // num heads
|
||||
int head_size, // head_size for Q/K, v_head_size for V
|
||||
int hidden_size, // hidden_size for Q/K, v_hidden_size for V
|
||||
OpKernelContext* context) {
|
||||
// Note: the comments below will refer to Q's dimensions for simplicity
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
ProcessBroadcastSpanFuncs add_funcs{
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
|
||||
},
|
||||
[](BroadcastHelper& per_iter_bh) {
|
||||
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
|
||||
}}; // For element-wise add
|
||||
|
||||
// Get Q's bias from combined bias
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
std::vector<int64_t> bias_dims({hidden_size});
|
||||
gsl::span<const int64_t> bias_dims_span{bias_dims};
|
||||
TensorShape bias_shape(bias_dims_span);
|
||||
OrtValue bias;
|
||||
Tensor::InitOrtValue(element_type, bias_shape, allocator, bias);
|
||||
auto num_bias_elements = SafeInt<size_t>(hidden_size) * element_size;
|
||||
memcpy(bias.GetMutable<Tensor>()->MutableData<T>(), qkv_bias + bias_offset, num_bias_elements);
|
||||
|
||||
// Compute Q(BS, D) + bias(D) as broadcasted element-wise add
|
||||
{
|
||||
InputBroadcaster input_broadcaster(*bias.GetMutable<Tensor>(), *qkv);
|
||||
const InputBroadcaster& const_input_broadcaster = input_broadcaster;
|
||||
Tensor& output_tensor = *qkv_with_bias.GetMutable<Tensor>();
|
||||
|
||||
size_t span_size = input_broadcaster.GetSpanSize();
|
||||
size_t output_size = static_cast<ptrdiff_t>(output_tensor.Shape().Size());
|
||||
void* user_data = nullptr;
|
||||
|
||||
const int loop_len = static_cast<int>(output_size / span_size);
|
||||
double unit_cost = 1.0f;
|
||||
const auto cost = TensorOpCost{static_cast<double>(input_broadcaster.Input0ElementSize()) * span_size,
|
||||
static_cast<double>(output_tensor.DataType()->Size()) * span_size,
|
||||
unit_cost * span_size};
|
||||
auto tp = context->GetOperatorThreadPool();
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost,
|
||||
[span_size, &const_input_broadcaster, &output_tensor, &add_funcs, user_data](std::ptrdiff_t first_span,
|
||||
std::ptrdiff_t last_span) {
|
||||
InputBroadcaster segment_input_broadcaster(const_input_broadcaster);
|
||||
segment_input_broadcaster.AdvanceBy(first_span * span_size);
|
||||
|
||||
OutputBroadcaster segment_output_broadcaster(span_size, output_tensor,
|
||||
first_span * span_size, last_span * span_size);
|
||||
|
||||
BroadcastHelper segment_helper(segment_input_broadcaster, segment_output_broadcaster, user_data);
|
||||
BroadcastLooper(segment_helper, add_funcs);
|
||||
});
|
||||
}
|
||||
|
||||
// Reshape Q from BxSxD to BxNxSxH
|
||||
std::vector<int64_t> reshape_dims({batch_size, num_heads, sequence_length, head_size});
|
||||
gsl::span<const int64_t> reshape_dims_span{reshape_dims};
|
||||
TensorShape qkv_final_dims(reshape_dims_span);
|
||||
qkv_with_bias.GetMutable<Tensor>()->Reshape(qkv_final_dims);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator,
|
||||
int batch_size, int num_heads, int sequence_length, int head_size,
|
||||
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out) {
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
std::vector<int64_t> new_dims({batch_size, num_heads, sequence_length, head_size});
|
||||
gsl::span<const int64_t> new_dims_span{new_dims};
|
||||
TensorShape v_BNLH(new_dims_span);
|
||||
Tensor::InitOrtValue(element_type, v_BNLH, allocator, out);
|
||||
if (bias == nullptr) {
|
||||
std::unique_ptr<Tensor> reshaped;
|
||||
if (in->Shape().GetDims().size() == 3) {
|
||||
reshaped = std::make_unique<Tensor>(in->DataType(), in->Shape(), const_cast<void*>(in->DataRaw()), in->Location());
|
||||
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(reshaped.get(), batch_size, sequence_length, num_heads, head_size));
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH((reshaped == nullptr) ? in : reshaped.get(), out));
|
||||
} else {
|
||||
const auto* qkv_bias = bias->Data<T>();
|
||||
if (sequence_length == 1) {
|
||||
ORT_RETURN_IF_ERROR(AddBiasReshape(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(AddBiasTranspose(in, qkv_bias, out, bias_offset, batch_size, sequence_length, num_heads, head_size, num_heads * head_size, context));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* query = context->Input<Tensor>(0);
|
||||
|
|
|
|||
|
|
@ -36,6 +36,71 @@ RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: rotary embedding in place
|
||||
template <typename T>
|
||||
Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input,
|
||||
const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output,
|
||||
bool interleaved) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int n_heads = parameters.num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
const int head_stride = parameters.head_stride;
|
||||
const int seq_stride = parameters.seq_stride;
|
||||
const int batch_stride = parameters.batch_stride;
|
||||
const int position_ids_format = parameters.position_ids_format;
|
||||
const int rotary_emb_dim = parameters.rotary_embedding_dim;
|
||||
const int half_rotary_emb_dim = rotary_emb_dim / 2;
|
||||
|
||||
const int loop_len = batch_size * sequence_length * n_heads;
|
||||
const double cost = static_cast<double>(rotary_emb_dim);
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
|
||||
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
|
||||
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
|
||||
const int n = static_cast<int>(ptr % n_heads);
|
||||
|
||||
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
|
||||
|
||||
const T* input_data = input + block_offset;
|
||||
T* output_data = output + block_offset;
|
||||
|
||||
// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
|
||||
const int position_id = (position_ids_format == 0)
|
||||
? static_cast<int>(position_ids[0]) + s
|
||||
: static_cast<int>(position_ids[b * sequence_length + s]);
|
||||
const int cache_offset = position_id * half_rotary_emb_dim;
|
||||
const T* cos_data = cos_cache + cache_offset;
|
||||
const T* sin_data = sin_cache + cache_offset;
|
||||
|
||||
int cache_idx = 0;
|
||||
T sign = 0;
|
||||
int j = 0;
|
||||
for (int i = 0; i < rotary_emb_dim; i++) {
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_emb_dim;
|
||||
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
|
||||
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_emb_dim;
|
||||
sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
|
||||
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
|
||||
}
|
||||
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
|
||||
}
|
||||
for (int i = rotary_emb_dim; i < head_size; i++) {
|
||||
output_data[i] = input_data[i];
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status RunRotaryEmbedding<float>(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input,
|
||||
const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output,
|
||||
bool interleaved);
|
||||
|
||||
template <typename T>
|
||||
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
|
|
@ -65,72 +130,12 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
|
|||
const T* sin_cache_data = sin_cache->Data<T>();
|
||||
T* output_dest = output->MutableData<T>();
|
||||
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int n_heads = parameters.num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
const int position_ids_format = parameters.position_ids_format;
|
||||
const int rotary_emb_dim = parameters.rotary_embedding_dim;
|
||||
const int half_rotary_emb_dim = rotary_emb_dim / 2;
|
||||
|
||||
// Default input tensor shape is [batch, seq_len, hidden_size]
|
||||
int head_stride = head_size;
|
||||
int seq_stride = n_heads * head_stride;
|
||||
int batch_stride = sequence_length * seq_stride;
|
||||
if (parameters.transposed) {
|
||||
// Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
|
||||
seq_stride = head_size;
|
||||
head_stride = sequence_length * seq_stride;
|
||||
batch_stride = n_heads * head_stride;
|
||||
}
|
||||
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
|
||||
const int loop_len = batch_size * sequence_length * n_heads;
|
||||
const double cost = static_cast<double>(rotary_emb_dim);
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
|
||||
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
|
||||
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
|
||||
const int n = static_cast<int>(ptr % n_heads);
|
||||
|
||||
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
|
||||
|
||||
const T* input_data = input_src + block_offset;
|
||||
T* output_data = output_dest + block_offset;
|
||||
|
||||
// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
|
||||
const int position_id = (position_ids_format == 0)
|
||||
? static_cast<int>(pos_ids_data[0]) + s
|
||||
: static_cast<int>(pos_ids_data[b * sequence_length + s]);
|
||||
const int cache_offset = position_id * half_rotary_emb_dim;
|
||||
const T* cos_data = cos_cache_data + cache_offset;
|
||||
const T* sin_data = sin_cache_data + cache_offset;
|
||||
|
||||
int cache_idx = 0;
|
||||
T sign = 0;
|
||||
int j = 0;
|
||||
for (int i = 0; i < rotary_emb_dim; i++) {
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_emb_dim;
|
||||
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
|
||||
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_emb_dim;
|
||||
sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
|
||||
j = (i + half_rotary_emb_dim) % rotary_emb_dim;
|
||||
}
|
||||
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
|
||||
}
|
||||
for (int i = rotary_emb_dim; i < head_size; i++) {
|
||||
output_data[i] = input_data[i];
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return Status::OK();
|
||||
return RunRotaryEmbedding<T>(tp, parameters, input_src, pos_ids_data, cos_cache_data, sin_cache_data, output_dest,
|
||||
interleaved);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "rotary_embedding_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T>
|
||||
Status RunRotaryEmbedding(onnxruntime::concurrency::ThreadPool* tp, rotary_embedding_helper::RotaryParameters parameters, const T* input,
|
||||
const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output,
|
||||
bool interleaved);
|
||||
|
||||
template <typename T>
|
||||
class RotaryEmbedding final : public OpKernel {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ struct RotaryParameters {
|
|||
int rotary_embedding_dim; // Rotary embedding dimension.
|
||||
int num_heads; // num_heads = hidden_size / head_size
|
||||
int max_sequence_length; // Sequence length used by cos/sin cache
|
||||
int head_stride; // Head stride
|
||||
int seq_stride; // Sequence stride
|
||||
int batch_stride; // Batch stride
|
||||
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
|
||||
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
|
||||
};
|
||||
|
|
@ -116,6 +119,23 @@ Status CheckInputs(const T* input,
|
|||
"head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
|
||||
}
|
||||
|
||||
num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
|
||||
// Calculate stride values
|
||||
int head_stride;
|
||||
int seq_stride;
|
||||
int batch_stride;
|
||||
if (transposed) {
|
||||
// Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
|
||||
seq_stride = head_size;
|
||||
head_stride = sequence_length * seq_stride;
|
||||
batch_stride = num_heads * head_stride;
|
||||
} else {
|
||||
// Default input tensor shape is [batch, seq_len, hidden_size]
|
||||
head_stride = head_size;
|
||||
seq_stride = num_heads * head_stride;
|
||||
batch_stride = sequence_length * seq_stride;
|
||||
}
|
||||
|
||||
// Set rotary parameters
|
||||
if (parameters != nullptr) {
|
||||
RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters);
|
||||
|
|
@ -123,8 +143,11 @@ Status CheckInputs(const T* input,
|
|||
output_parameters->sequence_length = sequence_length;
|
||||
output_parameters->hidden_size = hidden_size;
|
||||
output_parameters->head_size = head_size;
|
||||
output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
|
||||
output_parameters->num_heads = num_heads;
|
||||
output_parameters->max_sequence_length = max_sequence_length;
|
||||
output_parameters->head_stride = head_stride;
|
||||
output_parameters->seq_stride = seq_stride;
|
||||
output_parameters->batch_stride = batch_stride;
|
||||
output_parameters->position_ids_format = position_ids_format;
|
||||
output_parameters->transposed = transposed;
|
||||
output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
|
||||
|
|
@ -257,6 +258,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
|
||||
|
|
|
|||
|
|
@ -1009,6 +1009,9 @@ constexpr const char* GroupQueryAttention_ver1_doc = R"DOC(
|
|||
Group Query Self/Cross Attention.
|
||||
|
||||
Supports different number of heads for q and kv. Only supports causal or local attention.
|
||||
Supports rotary position embedding.
|
||||
Supports k-v cache.
|
||||
CPU EP supports fp32... CUDA EP supports fp16.
|
||||
)DOC";
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(
|
||||
|
|
@ -1094,7 +1097,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +"
|
||||
"kv_sequence_length.",
|
||||
"T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.")
|
||||
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
GroupQueryAttentionTypeAndShapeInference(ctx, 3);
|
||||
|
|
|
|||
1884
onnxruntime/test/python/transformers/test_gqa_cpu.py
Normal file
1884
onnxruntime/test/python/transformers/test_gqa_cpu.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue