mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Attention kernel update to handle different Q,K,V hidden sizes (#8039)
* changes working to convert akv nodes * changes to replace nodes * changes to accommodate qkv hidden sizes as attributes * kernel to accept qkv_hidden_size attributes * Working till compute for varied dimension, todo applyattention() * changes to make all regression tests work * inference running successfully without prepack * success inference with pre-pack weights * add test for diff sizes * bias shape need not be a mul of 3 * get the output_hidden_size from input * infer output shape from input * merge with master * cleaning up files that got merged wrong * accuracy at accepted level * added unit test case for different dimensions * all unit tests passing * packed weights working for attention * prepacked weights working * added test case for newly added extra qk input * updated unit test to test only extra add qk * fixing build error * removing few debugs * reverting test changes * all python test passing * cleaning up * new unit test added, major clean up of code * removed extra code * minor * minor fix to tests * prepack weights code cleaned up * compacted compute() in attention.cc * reformat compute() * making a parameter T * adding 3 q,k,v buffers in all cases * fixing build * running tests only on cpu * Updating docs * trigger ci builds * Addressing comments in PR * addressing some more comments * get add_qk_str from add_qk node directly * updating docs, added extra check to verify attn inputs * Optimized the extra add by parallelizing * added attention_shape to symbolic_shape_infer.py * minor refactoring to address comments
This commit is contained in:
parent
c3129306e5
commit
afce0e2543
13 changed files with 577 additions and 124 deletions
|
|
@ -84,11 +84,13 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dl>
|
||||
<dt><tt>num_heads</tt> : int (required)</dt>
|
||||
<dd>Number of attention heads</dd>
|
||||
<dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
|
||||
<dd>Hidden layer sizes of Q, K, V paths in Attention</dd>
|
||||
<dt><tt>unidirectional</tt> : int</dt>
|
||||
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (3 - 5)
|
||||
#### Inputs (3 - 6)
|
||||
|
||||
<dl>
|
||||
<dt><tt>input</tt> : T</dt>
|
||||
|
|
@ -101,6 +103,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).</dd>
|
||||
<dt><tt>past</tt> (optional) : T</dt>
|
||||
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).</dd>
|
||||
<dt><tt>extra_add</tt> (optional) : T</dt>
|
||||
<dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - 2)
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ Do not modify directly.*
|
|||
| |
|
||||
| |
|
||||
|**Operator Domain:** *com.microsoft*||||
|
||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|
||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|
||||
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|
||||
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|
||||
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|
||||
|
|
@ -705,7 +705,7 @@ Do not modify directly.*
|
|||
| |
|
||||
| |
|
||||
|**Operator Domain:** *com.microsoft*||||
|
||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|
||||
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|BiasSoftmax|*in* data:**T**<br> *in* bias:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ class Attention : public OpKernel, public AttentionCPUBase {
|
|||
public:
|
||||
explicit Attention(const OpKernelInfo& info);
|
||||
|
||||
bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size, size_t input_hidden_size, const T* weights_data, size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights);
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
||||
|
|
@ -31,8 +33,13 @@ class Attention : public OpKernel, public AttentionCPUBase {
|
|||
/*out*/ bool& used_shared_buffers) override;
|
||||
|
||||
private:
|
||||
BufferUniquePtr packed_weights_;
|
||||
size_t packed_weights_size_ = 0;
|
||||
BufferUniquePtr q_packed_weights_;
|
||||
BufferUniquePtr k_packed_weights_;
|
||||
BufferUniquePtr v_packed_weights_;
|
||||
|
||||
size_t q_packed_weights_size_ = 0;
|
||||
size_t k_packed_weights_size_ = 0;
|
||||
size_t v_packed_weights_size_ = 0;
|
||||
TensorShape weight_shape_;
|
||||
};
|
||||
|
||||
|
|
@ -51,7 +58,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
const TensorShape& weights_shape,
|
||||
const TensorShape& bias_shape,
|
||||
const Tensor*& mask_index,
|
||||
const Tensor* past) const {
|
||||
const Tensor* past,
|
||||
const Tensor* extra_add_qk) const {
|
||||
// Input shapes:
|
||||
// input : (batch_size, sequence_length, input_hidden_size)
|
||||
// weights : (input_hidden_size, 3 * hidden_size)
|
||||
|
|
@ -61,10 +69,16 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
// or (batch_size, past_sequence_length + sequence_length)
|
||||
// or (batch_size, sequence_length, past_sequence_length + sequence_length)
|
||||
// past : (2, batch_size, num_heads, past_sequence_length, head_size)
|
||||
// extra_add_qk: (batch_size, num_heads, sequence_length, sequence_length)
|
||||
//
|
||||
// Where hidden_size = num_heads * head_size.
|
||||
// When a model is pruned (like some attention heads are removed), hidden_size < input_hidden_size.
|
||||
|
||||
if (past != nullptr && extra_add_qk != nullptr) {
|
||||
// past is used on GPT-2 model with past state, we don't have a case for extra add qk yet
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have past sequence and extra add qk");
|
||||
}
|
||||
|
||||
const auto& dims = input_shape.GetDims();
|
||||
if (dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
|
||||
|
|
@ -83,21 +97,53 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
"Input 1 dimension 0 should have same length as dimension 2 of input 0");
|
||||
}
|
||||
|
||||
int hidden_size = static_cast<int>(weights_dims[1]) / 3;
|
||||
if (3 * hidden_size != static_cast<int>(weights_dims[1])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 1 dimension 1 should be 3 times of hidden dimension");
|
||||
}
|
||||
|
||||
if (hidden_size % num_heads_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
|
||||
}
|
||||
|
||||
const auto& bias_dims = bias_shape.GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
|
||||
bias_dims.size());
|
||||
}
|
||||
|
||||
int hidden_size = 0;
|
||||
|
||||
if (qkv_hidden_sizes_.size() == 0) {
|
||||
hidden_size = static_cast<int>(weights_dims[1]) / 3;
|
||||
if (3 * hidden_size != static_cast<int>(weights_dims[1])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 1 dimension 1 should be 3 times of hidden dimension");
|
||||
}
|
||||
|
||||
if (hidden_size % num_heads_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
|
||||
}
|
||||
} else {
|
||||
int qkv_sizes = 0;
|
||||
|
||||
if (qkv_hidden_sizes_.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"qkv_hidden_sizes attribute should have 3 elements");
|
||||
}
|
||||
|
||||
if (qkv_hidden_sizes_[0] != qkv_hidden_sizes_[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"qkv_hidden_sizes first element should be same as the second");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) {
|
||||
if (qkv_hidden_sizes_[i] % num_heads_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads:", qkv_hidden_sizes_[i]);
|
||||
}
|
||||
|
||||
qkv_sizes += static_cast<int>(qkv_hidden_sizes_[i]);
|
||||
}
|
||||
|
||||
int qkv_hidden_sizes_sum = static_cast<int>(weights_dims[1]);
|
||||
if (qkv_hidden_sizes_sum != qkv_sizes) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qkv_sizes doesn't match the wights dimension");
|
||||
}
|
||||
|
||||
hidden_size = static_cast<int>(qkv_hidden_sizes_[2]);
|
||||
}
|
||||
|
||||
if (bias_dims[0] != weights_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'");
|
||||
|
|
@ -157,6 +203,33 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
mask_dims.size());
|
||||
}
|
||||
}
|
||||
|
||||
if (extra_add_qk != nullptr) {
|
||||
const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims();
|
||||
|
||||
if (extra_add_qk_dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' is expected to have 4 dimensions, got ",
|
||||
extra_add_qk_dims.size());
|
||||
}
|
||||
|
||||
if (extra_add_qk_dims[0] != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ",
|
||||
extra_add_qk_dims[0]);
|
||||
}
|
||||
if (extra_add_qk_dims[1] != num_heads_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ",
|
||||
extra_add_qk_dims[1]);
|
||||
}
|
||||
if (extra_add_qk_dims[2] != sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ",
|
||||
extra_add_qk_dims[2]);
|
||||
}
|
||||
if (extra_add_qk_dims[3] != sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ",
|
||||
extra_add_qk_dims[3]);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -170,7 +243,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
|
||||
}
|
||||
|
||||
return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past);
|
||||
return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, nullptr);
|
||||
}
|
||||
|
||||
Tensor* AttentionBase::GetPresent(OpKernelContext* context,
|
||||
|
|
@ -203,6 +276,71 @@ template <typename T>
|
|||
Attention<T>::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
|
||||
AllocatorPtr alloc,
|
||||
size_t head_size,
|
||||
size_t input_hidden_size,
|
||||
const T* weights_data,
|
||||
size_t weight_matrix_col_size,
|
||||
/*out*/ PrePackedWeights* prepacked_weights) {
|
||||
size_t packb_size = MlasGemmPackBSize(head_size, input_hidden_size);
|
||||
if (packb_size == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t loop_len = static_cast<size_t>(num_heads_);
|
||||
size_t packed_weights_data_size = packb_size * loop_len; // The same size would be computed by AllocArray() below
|
||||
auto* packed_weights_data = static_cast<uint8_t*>(alloc->AllocArray(packb_size, loop_len));
|
||||
|
||||
// Initialize memory to 0 as there could be some padding associated with pre-packed
|
||||
// buffer memory and we do not want it uninitialized and generate different hashes
|
||||
// if and when we try to cache this pre-packed buffer for sharing between sessions.
|
||||
memset(packed_weights_data, 0, packed_weights_data_size);
|
||||
switch (qkv_index) {
|
||||
case 0:
|
||||
q_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
q_packed_weights_size_ = packb_size;
|
||||
break;
|
||||
case 1:
|
||||
k_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
k_packed_weights_size_ = packb_size;
|
||||
break;
|
||||
case 2:
|
||||
v_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
v_packed_weights_size_ = packb_size;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < loop_len; i++) {
|
||||
MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data);
|
||||
packed_weights_data += packb_size;
|
||||
weights_data += head_size;
|
||||
}
|
||||
|
||||
bool share_prepacked_weights = (prepacked_weights != nullptr);
|
||||
if (share_prepacked_weights) {
|
||||
switch (qkv_index) {
|
||||
case 0:
|
||||
prepacked_weights->buffers_.push_back(std::move(q_packed_weights_));
|
||||
break;
|
||||
case 1:
|
||||
prepacked_weights->buffers_.push_back(std::move(k_packed_weights_));
|
||||
break;
|
||||
case 2:
|
||||
prepacked_weights->buffers_.push_back(std::move(v_packed_weights_));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
|
||||
/*out*/ bool& is_packed,
|
||||
|
|
@ -219,46 +357,47 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
const size_t input_hidden_size = static_cast<size_t>(weights_dims[0]);
|
||||
const size_t hidden_size_x3 = static_cast<size_t>(weights_dims[1]);
|
||||
const size_t hidden_size = hidden_size_x3 / 3;
|
||||
const size_t head_size = hidden_size / num_heads_;
|
||||
|
||||
// Bail out if the weights shape has an unexpected shape.
|
||||
if ((hidden_size == 0) || ((hidden_size % num_heads_) != 0) || (hidden_size_x3 != 3 * hidden_size)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto* weights_data = weights.Data<T>();
|
||||
const size_t input_hidden_size = static_cast<size_t>(weights_dims[0]);
|
||||
size_t q_hidden_size, k_hidden_size, v_hidden_size;
|
||||
|
||||
packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size);
|
||||
if (packed_weights_size_ == 0) {
|
||||
if (qkv_hidden_sizes_.size() != 0) {
|
||||
q_hidden_size = qkv_hidden_sizes_[0];
|
||||
k_hidden_size = qkv_hidden_sizes_[1];
|
||||
v_hidden_size = qkv_hidden_sizes_[2];
|
||||
|
||||
if (q_hidden_size == 0 || k_hidden_size == 0 || v_hidden_size == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (q_hidden_size % num_heads_ != 0 || k_hidden_size % num_heads_ != 0 || v_hidden_size % num_heads_ != 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
const size_t hidden_size_x3 = static_cast<size_t>(weights_dims[1]);
|
||||
const size_t hidden_size = hidden_size_x3 / 3;
|
||||
|
||||
if (hidden_size % num_heads_ != 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
q_hidden_size = hidden_size;
|
||||
k_hidden_size = hidden_size;
|
||||
v_hidden_size = hidden_size;
|
||||
}
|
||||
|
||||
const size_t q_head_size = q_hidden_size / num_heads_;
|
||||
const size_t k_head_size = k_hidden_size / num_heads_;
|
||||
const size_t v_head_size = v_hidden_size / num_heads_;
|
||||
const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size;
|
||||
|
||||
if (!IsPackWeightsSuccessful(0, alloc, q_head_size, input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
|
||||
!IsPackWeightsSuccessful(1, alloc, k_head_size, input_hidden_size, weights_data + (num_heads_ * q_head_size), weight_matrix_col_size, prepacked_weights) ||
|
||||
!IsPackWeightsSuccessful(2, alloc, v_head_size, input_hidden_size, weights_data + (num_heads_ * (q_head_size + k_head_size)), weight_matrix_col_size, prepacked_weights)) {
|
||||
// we are not cleaning up anything, assuming caller takes care of this
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const size_t loop_len = static_cast<size_t>(3) * num_heads_;
|
||||
size_t packed_weights_data_size = packed_weights_size_ * loop_len; // The same size would be computed by AllocArray() below
|
||||
auto* packed_weights_data = static_cast<uint8_t*>(alloc->AllocArray(packed_weights_size_, loop_len));
|
||||
|
||||
// Initialize memory to 0 as there could be some padding associated with pre-packed
|
||||
// buffer memory and we do not want it uninitialized and generate different hashes
|
||||
// if and when we try to cache this pre-packed buffer for sharing between sessions.
|
||||
memset(packed_weights_data, 0, packed_weights_data_size);
|
||||
|
||||
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
|
||||
for (size_t i = 0; i < loop_len; i++) {
|
||||
MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, hidden_size_x3, packed_weights_data);
|
||||
packed_weights_data += packed_weights_size_;
|
||||
weights_data += head_size;
|
||||
}
|
||||
|
||||
bool share_prepacked_weights = (prepacked_weights != nullptr);
|
||||
if (share_prepacked_weights) {
|
||||
prepacked_weights->buffers_.push_back(std::move(packed_weights_));
|
||||
prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size);
|
||||
}
|
||||
|
||||
is_packed = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -272,7 +411,9 @@ Status Attention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& pre
|
|||
}
|
||||
|
||||
used_shared_buffers = true;
|
||||
packed_weights_ = std::move(prepacked_buffers[0]);
|
||||
q_packed_weights_ = std::move(prepacked_buffers[0]);
|
||||
k_packed_weights_ = std::move(prepacked_buffers[1]);
|
||||
v_packed_weights_ = std::move(prepacked_buffers[2]);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -280,25 +421,35 @@ Status Attention<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& pre
|
|||
template <typename T>
|
||||
Status Attention<T>::Compute(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = packed_weights_ ? nullptr : context->Input<Tensor>(1);
|
||||
const Tensor* weights = q_packed_weights_ ? nullptr : context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
const Tensor* past = context->Input<Tensor>(4);
|
||||
const Tensor* extra_add_qk = context->Input<Tensor>(5);
|
||||
|
||||
const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_);
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
|
||||
weights_shape,
|
||||
bias->Shape(),
|
||||
mask_index,
|
||||
past));
|
||||
past,
|
||||
extra_add_qk));
|
||||
|
||||
const auto& shape = input->Shape().GetDims();
|
||||
const int batch_size = static_cast<int>(shape[0]);
|
||||
const int sequence_length = static_cast<int>(shape[1]);
|
||||
const int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
int hidden_size;
|
||||
|
||||
if (qkv_hidden_sizes_.size() == 0) {
|
||||
const auto& weights_dims = weights_shape.GetDims();
|
||||
hidden_size = static_cast<int>(weights_dims[1]) / 3;
|
||||
} else {
|
||||
hidden_size = static_cast<int>(qkv_hidden_sizes_[2]);
|
||||
}
|
||||
|
||||
const auto& weights_dims = weights_shape.GetDims();
|
||||
const int hidden_size = static_cast<int>(weights_dims[1]) / 3;
|
||||
const int head_size = hidden_size / num_heads_;
|
||||
|
||||
std::vector<int64_t> output_shape(3);
|
||||
|
|
@ -309,19 +460,35 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
|
||||
int q_hidden_size = 0;
|
||||
int k_hidden_size = 0;
|
||||
int v_hidden_size = 0;
|
||||
if (qkv_hidden_sizes_.size() == 0) {
|
||||
q_hidden_size = hidden_size;
|
||||
k_hidden_size = hidden_size;
|
||||
v_hidden_size = hidden_size;
|
||||
} else {
|
||||
q_hidden_size = static_cast<int>(qkv_hidden_sizes_[0]);
|
||||
k_hidden_size = static_cast<int>(qkv_hidden_sizes_[1]);
|
||||
v_hidden_size = static_cast<int>(qkv_hidden_sizes_[2]);
|
||||
}
|
||||
const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
|
||||
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
// Compute Q, K, V
|
||||
// gemm_data(BS, 3NH) = input(BS, D) x weights(D, 3NH) + bias(3NH)
|
||||
// D (input_hidden_size) is hidden dimension of input, where D could be larger than hidden_size (NH) when model is pruned.
|
||||
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
|
||||
// gemm_data(BS, NT) = input(BS, D) x weights(D, NT) + bias(NT)
|
||||
// D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes
|
||||
// (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively
|
||||
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size);
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
|
||||
|
||||
auto Q = reinterpret_cast<T*>(gemm_data);
|
||||
auto K = Q + static_cast<size_t>(batch_size) * sequence_length * hidden_size;
|
||||
auto V = K + static_cast<size_t>(batch_size) * sequence_length * hidden_size;
|
||||
auto K = Q + static_cast<size_t>(batch_size) * sequence_length * q_hidden_size;
|
||||
auto V = K + static_cast<size_t>(batch_size) * sequence_length * k_hidden_size;
|
||||
|
||||
T* QKV[3] = {Q, K, V};
|
||||
|
||||
{
|
||||
|
|
@ -339,14 +506,25 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
const int qkv_index = static_cast<int>(i % 3);
|
||||
|
||||
int input_offset = batch_index * sequence_length * input_hidden_size;
|
||||
int weights_offset = qkv_index * hidden_size + head_index * head_size;
|
||||
|
||||
T* qkv_dest = QKV[qkv_index];
|
||||
int head_size = qkv_head_size[qkv_index];
|
||||
int weights_offset = 0;
|
||||
int bias_offset = qkv_index * q_hidden_size + head_index * head_size;
|
||||
|
||||
if (q_packed_weights_ == nullptr) {
|
||||
weights_offset = bias_offset;
|
||||
} else {
|
||||
weights_offset = head_index * head_size;
|
||||
}
|
||||
|
||||
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
|
||||
|
||||
// TODO!! memcpy here makes it not worthwhile to use Gemm batch. Possible to post process?
|
||||
// broadcast 3NH -> (3.B.N.S.H)
|
||||
const T* broadcast_data_src = bias_data + weights_offset;
|
||||
// broadcast NH -> (B.N.S.H) for each of Q, K, V
|
||||
const T* broadcast_data_src = bias_data + bias_offset;
|
||||
T* broadcast_data_dest = QKV[qkv_index] + qkv_offset;
|
||||
|
||||
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
|
||||
memcpy(broadcast_data_dest, broadcast_data_src, head_size * sizeof(T));
|
||||
broadcast_data_dest += head_size;
|
||||
|
|
@ -354,15 +532,22 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
// original transposed iteration
|
||||
// A: input (BxSxD) (B.)S x D S x D
|
||||
// B: weights (Dx3xNxH) D x (3.N.)H D x H
|
||||
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
|
||||
if (packed_weights_) {
|
||||
const auto* packed_weight =
|
||||
static_cast<const uint8_t*>(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size);
|
||||
// B: weights (DxNxT) D x (N.)T D x H
|
||||
// C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H
|
||||
if (q_packed_weights_) {
|
||||
uint8_t* packed_weight;
|
||||
if (qkv_index == 0) {
|
||||
packed_weight = static_cast<uint8_t*>(q_packed_weights_.get()) + q_packed_weights_size_ * (weights_offset / head_size);
|
||||
} else if (qkv_index == 1) {
|
||||
packed_weight = static_cast<uint8_t*>(k_packed_weights_.get()) + k_packed_weights_size_ * (weights_offset / head_size);
|
||||
} else {
|
||||
packed_weight = static_cast<uint8_t*>(v_packed_weights_.get()) + v_packed_weights_size_ * (weights_offset / head_size);
|
||||
}
|
||||
|
||||
MlasGemm(
|
||||
CblasNoTrans, // TransA = no
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
head_size, // N = H
|
||||
input_hidden_size, // K = D
|
||||
1.0f, // alpha
|
||||
input_data + input_offset, // A
|
||||
|
|
@ -374,20 +559,20 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
nullptr); // use single-thread
|
||||
} else {
|
||||
math::GemmEx<float, ThreadPool>(
|
||||
CblasNoTrans, // TransA = no
|
||||
CblasNoTrans, // TransB = no
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
input_hidden_size, // K = D
|
||||
1.0f, // alpha
|
||||
input_data + input_offset, // A
|
||||
input_hidden_size, // lda = D
|
||||
weights_data + weights_offset, // B
|
||||
3 * hidden_size, // ldb = 3NH
|
||||
1.0f, // beta
|
||||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
nullptr // use single-thread
|
||||
CblasNoTrans, // TransA = no
|
||||
CblasNoTrans, // TransB = no
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
input_hidden_size, // K = D
|
||||
1.0f, // alpha
|
||||
input_data + input_offset, // A
|
||||
input_hidden_size, // lda = D
|
||||
weights_data + weights_offset, // B
|
||||
q_hidden_size + k_hidden_size + v_hidden_size,// ldb = NH1 + NH2 + NH3
|
||||
1.0f, // beta
|
||||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
nullptr // use single-thread
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -397,8 +582,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
// Compute the attention score and apply the score to V
|
||||
return ApplyAttention(Q, K, V, mask_index, past, output,
|
||||
batch_size, sequence_length,
|
||||
head_size, hidden_size, context);
|
||||
qkv_head_size[0], qkv_head_size[2], v_hidden_size,
|
||||
extra_add_qk, context);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -33,16 +33,22 @@ class AttentionBase {
|
|||
num_heads_ = static_cast<int>(num_heads);
|
||||
|
||||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
|
||||
|
||||
if (!info.GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK() || qkv_hidden_sizes_.empty()) {
|
||||
qkv_hidden_sizes_.resize(0);
|
||||
}
|
||||
}
|
||||
|
||||
Status CheckInputs(const TensorShape& input_shape,
|
||||
const TensorShape& weights_shape,
|
||||
const TensorShape& bias_shape,
|
||||
const Tensor*& mask_index, // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr.
|
||||
const Tensor* past) const;
|
||||
const Tensor* past,
|
||||
const Tensor *extra_add_qk) const;
|
||||
|
||||
int num_heads_; // number of attention heads
|
||||
bool is_unidirectional_; // whether every token can only attend to previous tokens.
|
||||
std::vector<int64_t> qkv_hidden_sizes_; // Q, K, V path hidden layer sizes
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -26,8 +26,10 @@ class AttentionCPUBase : public AttentionBase {
|
|||
Tensor* output, // output tensor
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence length
|
||||
int head_size, // head size
|
||||
int hidden_size, // hidden size
|
||||
int qk_head_size, // qk_head_size
|
||||
int v_head_size, // head_size
|
||||
int v_hidden_size, // hidden_size
|
||||
const Tensor* extra_add_qk,// extra add in QK. Its size is BxNxSxS
|
||||
OpKernelContext* context) const {
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
|
|
@ -35,7 +37,7 @@ class AttentionCPUBase : public AttentionBase {
|
|||
auto* tp = context->GetOperatorThreadPool();
|
||||
|
||||
int past_sequence_length = 0;
|
||||
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
|
||||
Tensor* present = GetPresent(context, past, batch_size, v_head_size, sequence_length, past_sequence_length);
|
||||
|
||||
// Total sequence length including that of past state: S* = S' + S
|
||||
const int all_sequence_length = past_sequence_length + sequence_length;
|
||||
|
|
@ -61,18 +63,23 @@ class AttentionCPUBase : public AttentionBase {
|
|||
const T* past_data = past != nullptr ? past->template Data<T>() : nullptr;
|
||||
T* present_data = present != nullptr ? present->template MutableData<T>() : nullptr;
|
||||
|
||||
const T* extra_add_qk_data = nullptr;
|
||||
if (extra_add_qk != nullptr) {
|
||||
extra_add_qk_data = extra_add_qk->template Data<T>();
|
||||
}
|
||||
|
||||
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
|
||||
mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
|
||||
batch_size, sequence_length, past_sequence_length, head_size,
|
||||
past_data, present_data, tp);
|
||||
batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size,
|
||||
past_data, present_data, tp, extra_add_qk_data);
|
||||
|
||||
// Compute the attentionScore * Value. It does: out_tmp(B, N, S, H) = attention_probs(B, N, S, S*) x V(B, N, S*, H)
|
||||
auto out_tmp_data =
|
||||
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
|
||||
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
|
||||
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
|
||||
|
||||
ComputeVxAttentionScore(output->template MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs), V,
|
||||
batch_size, sequence_length, past_sequence_length, head_size, hidden_size,
|
||||
batch_size, sequence_length, past_sequence_length, v_head_size, v_hidden_size,
|
||||
past_data, present_data, tp);
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -96,7 +103,8 @@ class AttentionCPUBase : public AttentionBase {
|
|||
int head_size, // head size of self-attention
|
||||
const T* past, // past state
|
||||
T* present, // present state
|
||||
ThreadPool* tp) const {
|
||||
ThreadPool* tp,
|
||||
const T* extra_add_qk_data) const {
|
||||
const int all_sequence_length = past_sequence_length + sequence_length; // S* = S' + S
|
||||
const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // S' x H
|
||||
const size_t input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
|
||||
|
|
@ -140,6 +148,13 @@ class AttentionCPUBase : public AttentionBase {
|
|||
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, all_sequence_length, head_size, alpha,
|
||||
Q + input_chunk_length * i, k, 1.0,
|
||||
reinterpret_cast<T*>(attention_probs) + sequence_length * all_sequence_length * i, nullptr);
|
||||
|
||||
if (extra_add_qk_data != nullptr) {
|
||||
int extra_add_qk_offset = static_cast<int>(i) * sequence_length * all_sequence_length;
|
||||
for (int j = 0; j < sequence_length * all_sequence_length ; j++) {
|
||||
attention_probs[extra_add_qk_offset+j] += extra_add_qk_data[extra_add_qk_offset + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -158,7 +158,8 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
weights_shape,
|
||||
bias->Shape(),
|
||||
mask_index,
|
||||
past_tensor));
|
||||
past_tensor,
|
||||
nullptr));
|
||||
|
||||
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
|
||||
"input scale must be a scalar or 1D tensor of size 1");
|
||||
|
|
@ -286,7 +287,7 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
// Compute the attention score and apply the score to V
|
||||
return ApplyAttention(Q, K, V, mask_index, past_tensor, output,
|
||||
batch_size, sequence_length,
|
||||
head_size, hidden_size, context);
|
||||
head_size, head_size, hidden_size, nullptr, context);
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -397,17 +397,32 @@ void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int p
|
|||
|
||||
auto& bias_shape = getInputShape(ctx, 2);
|
||||
auto& bias_dims = bias_shape.dim();
|
||||
if (bias_dims.size() != 1 || bias_shape.dim(0).dim_value() % 3 != 0) {
|
||||
if (bias_dims.size() != 1) {
|
||||
fail_shape_inference("Invalid bias shape");
|
||||
}
|
||||
|
||||
std::vector<int64_t> qkv_hidden_sizes;
|
||||
getRepeatedAttribute(ctx, "qkv_hidden_sizes", qkv_hidden_sizes);
|
||||
|
||||
int64_t output_hidden_size;
|
||||
if (qkv_hidden_sizes.size() != 0) {
|
||||
if (qkv_hidden_sizes.size() != 3) {
|
||||
fail_shape_inference("qkv_hidden_sizes should have 3 elements")
|
||||
}
|
||||
output_hidden_size = qkv_hidden_sizes[2];
|
||||
} else {
|
||||
output_hidden_size = bias_shape.dim(0).dim_value() / 3;
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorShapeProto output_shape;
|
||||
for (auto& dim : input_dims) {
|
||||
*output_shape.add_dim() = dim;
|
||||
}
|
||||
output_shape.mutable_dim(2)->set_dim_value(bias_shape.dim(0).dim_value() / 3);
|
||||
|
||||
output_shape.mutable_dim(2)->set_dim_value(output_hidden_size);
|
||||
updateOutputShape(ctx, 0, output_shape);
|
||||
|
||||
// TODO does the extra output need any changes?
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
if (hasInputShape(ctx, past_input_index)) {
|
||||
auto& past_shape = getInputShape(ctx, past_input_index);
|
||||
|
|
@ -453,12 +468,17 @@ and present state are optional. Present state could appear in output even when p
|
|||
"Whether every token can only attend to previous tokens. Default value is 0.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr("qkv_hidden_sizes",
|
||||
"Hidden layer sizes of Q, K, V paths in Attention",
|
||||
AttributeProto::INTS,
|
||||
OPTIONAL_VALUE)
|
||||
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T")
|
||||
.Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T")
|
||||
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
|
||||
.Input(3, "mask_index", "Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)"
|
||||
"or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
|
||||
.Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
|
||||
.Input(5, "extra_add", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "T", OpSchema::Optional)
|
||||
.Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
|
||||
.Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional)
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
|
||||
|
|
@ -545,6 +565,7 @@ and present state are optional. Present state could appear in output even when p
|
|||
.TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
constexpr int past_input_index = 8;
|
||||
|
||||
AttentionTypeAndShapeInference(ctx, past_input_index);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -1479,7 +1479,12 @@ class SymbolicShapeInference:
|
|||
shape = self._get_shape(node, 0)
|
||||
shape_bias = self._get_shape(node, 2)
|
||||
assert len(shape) == 3 and len(shape_bias) == 1
|
||||
shape[2] = int(shape_bias[0] / 3)
|
||||
qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes')
|
||||
if qkv_hidden_sizes_attr is not None:
|
||||
assert len(qkv_hidden_sizes_attr) == 3
|
||||
shape[2] = int(qkv_hidden_sizes_attr[2])
|
||||
else:
|
||||
shape[2] = int(shape_bias[0] / 3)
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from onnx import helper, numpy_helper, TensorProto, NodeProto
|
|||
from onnx_model import OnnxModel
|
||||
from fusion_base import Fusion
|
||||
from fusion_utils import FusionUtils, NumpyHelper
|
||||
from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
@ -126,9 +127,29 @@ class FusionAttention(Fusion):
|
|||
|
||||
return num_heads, hidden_size
|
||||
|
||||
def get_add_qk_str(self, add_qk: NodeProto):
|
||||
# Note: Does not work for dynamic models, reshape node shape inference would fail with more than 2 dims being -1
|
||||
# The graph input must be named "input_ids"; this may be changed if needed
|
||||
inputs_ids = self.model.find_graph_input("input_ids")
|
||||
if inputs_ids == None:
|
||||
logger.debug("no input with name \"input_ids\"")
|
||||
return None
|
||||
batch_size, seq_len = get_shape_from_type_proto(inputs_ids.type)
|
||||
|
||||
if batch_size < 0 or seq_len < 0:
|
||||
logger.debug(f"batch_size: {batch_size} and seq_len {seq_len} cannot be -ve")
|
||||
return None
|
||||
shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model)
|
||||
shape_infer_helper.infer({"batch_size": batch_size, "seq_len": seq_len})
|
||||
if shape_infer_helper.get_edge_shape(add_qk.input[0]) != shape_infer_helper.get_edge_shape(add_qk.input[1]):
|
||||
logger.debug(f"the shape of two inputs of {add_qk} is not same")
|
||||
return None
|
||||
|
||||
return add_qk.input[1]
|
||||
|
||||
def create_attention_node(self, mask_index: str, q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto,
|
||||
q_add: NodeProto, k_add: NodeProto, v_add: NodeProto, num_heads: int, hidden_size: int,
|
||||
input: str, output: str) -> Union[NodeProto, None]:
|
||||
input: str, output: str, add_qk_str: str) -> Union[NodeProto, None]:
|
||||
""" Create an Attention node.
|
||||
|
||||
Args:
|
||||
|
|
@ -165,6 +186,7 @@ class FusionAttention(Fusion):
|
|||
return None
|
||||
if not (k_weight and v_weight and q_bias and k_bias):
|
||||
return None
|
||||
|
||||
qw = NumpyHelper.to_array(q_weight)
|
||||
kw = NumpyHelper.to_array(k_weight)
|
||||
vw = NumpyHelper.to_array(v_weight)
|
||||
|
|
@ -246,6 +268,10 @@ class FusionAttention(Fusion):
|
|||
if mask_index is not None:
|
||||
attention_inputs.append(mask_index)
|
||||
|
||||
if add_qk_str is not None:
|
||||
attention_inputs.append("")
|
||||
attention_inputs.append(add_qk_str)
|
||||
|
||||
attention_node = helper.make_node('Attention',
|
||||
inputs=attention_inputs,
|
||||
outputs=[output],
|
||||
|
|
@ -334,7 +360,7 @@ class FusionAttention(Fusion):
|
|||
logger.debug("fuse_attention: failed to match v path")
|
||||
return
|
||||
(_, _, add_v, matmul_v) = v_nodes
|
||||
|
||||
|
||||
is_distill = False
|
||||
is_distill_add = False
|
||||
qk_paths = {
|
||||
|
|
@ -365,7 +391,7 @@ class FusionAttention(Fusion):
|
|||
if is_distill:
|
||||
(_, where_qk, matmul_qk, _) = qk_nodes
|
||||
elif is_distill_add:
|
||||
(_, _, where_qk, matmul_qk) = qk_nodes
|
||||
(_, add_qk, where_qk, matmul_qk) = qk_nodes
|
||||
else:
|
||||
(_, add_qk, _, matmul_qk) = qk_nodes
|
||||
|
||||
|
|
@ -392,6 +418,7 @@ class FusionAttention(Fusion):
|
|||
|
||||
# Note that Cast might be removed by OnnxRuntime so we match two patterns here.
|
||||
mask_nodes = None
|
||||
add_qk_str = None
|
||||
if is_distill:
|
||||
_, mask_nodes, _ = self.model.match_parent_paths(where_qk,
|
||||
[(['Expand', 'Reshape', 'Equal'], [0, 0, 0]),
|
||||
|
|
@ -401,6 +428,11 @@ class FusionAttention(Fusion):
|
|||
_, mask_nodes, _ = self.model.match_parent_paths(
|
||||
where_qk, [(['Cast', 'Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0, 0]),
|
||||
(['Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0])], output_name_to_node)
|
||||
if add_qk is not None:
|
||||
add_qk_str = self.get_add_qk_str(add_qk)
|
||||
if add_qk_str is None:
|
||||
logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
|
||||
return
|
||||
else:
|
||||
_, mask_nodes, _ = self.model.match_parent_paths(
|
||||
add_qk, [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0, 0]),
|
||||
|
|
@ -419,7 +451,7 @@ class FusionAttention(Fusion):
|
|||
# the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
|
||||
new_node = self.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v,
|
||||
q_num_heads, self.hidden_size, root_input,
|
||||
attention_last_node.output[0])
|
||||
attention_last_node.output[0], add_qk_str)
|
||||
if new_node is None:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ class BertOnnxModelKeras(BertOnnxModelTF):
|
|||
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
|
||||
add_q, add_k, add_v, self.num_heads,
|
||||
self.hidden_size, parent.output[0],
|
||||
reshape_qkv.output[0])
|
||||
reshape_qkv.output[0], None)
|
||||
if attention_node is None:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -384,7 +384,7 @@ class BertOnnxModelTF(BertOnnxModel):
|
|||
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_k, matmul_q, matmul_v,
|
||||
add_k, add_q, add_v, self.num_heads,
|
||||
self.hidden_size, parent.output[0],
|
||||
qkv_nodes[2].output[0])
|
||||
qkv_nodes[2].output[0], None)
|
||||
if attention_node is None:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -38,11 +38,14 @@ static void RunAttentionTest(
|
|||
MaskIndexType mask_index_type = kMaskIndexEnd,
|
||||
int input_hidden_size = 0,
|
||||
int max_sequence_length = 0,
|
||||
bool only_enable_cuda = false) {
|
||||
bool only_enable_cuda = false,
|
||||
bool only_enable_cpu = false,
|
||||
std::vector<int32_t> qkv_sizes = {},
|
||||
const std::vector<float>& extra_add_data = {}) {
|
||||
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
|
||||
|
||||
int min_cuda_architecture = use_float16 ? 530 : 0;
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant;
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu;
|
||||
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda;
|
||||
|
||||
int head_size = hidden_size / number_of_heads;
|
||||
|
|
@ -51,9 +54,21 @@ static void RunAttentionTest(
|
|||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
|
||||
tester.AddAttribute<int64_t>("unidirectional", static_cast<int64_t>(is_unidirectional ? 1 : 0));
|
||||
|
||||
int32_t matrix_size;
|
||||
int32_t output_hidden_size;
|
||||
if (qkv_sizes.size() != 0) {
|
||||
matrix_size = qkv_sizes[0] + qkv_sizes[1] + qkv_sizes[2];
|
||||
std::vector<int64_t> sizes_attribute{qkv_sizes[0], qkv_sizes[1], qkv_sizes[2]};
|
||||
tester.AddAttribute<std::vector<int64_t>>("qkv_hidden_sizes", sizes_attribute);
|
||||
output_hidden_size = qkv_sizes[2];
|
||||
} else {
|
||||
matrix_size = 3 * hidden_size;
|
||||
output_hidden_size = hidden_size;
|
||||
}
|
||||
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, input_hidden_size};
|
||||
std::vector<int64_t> weights_dims = {input_hidden_size, 3 * hidden_size};
|
||||
std::vector<int64_t> bias_dims = {3 * hidden_size};
|
||||
std::vector<int64_t> weights_dims = {input_hidden_size, matrix_size};
|
||||
std::vector<int64_t> bias_dims = {matrix_size};
|
||||
|
||||
std::vector<int64_t> mask_index_dims_1 = {batch_size};
|
||||
std::vector<int64_t> mask_index_dims_2 = {2 * batch_size};
|
||||
|
|
@ -88,7 +103,7 @@ static void RunAttentionTest(
|
|||
|
||||
std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, past_sequence_length, head_size};
|
||||
std::vector<int64_t> present_dims = {2, batch_size, number_of_heads, past_sequence_length + sequence_length, head_size};
|
||||
std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> output_dims = {batch_size, sequence_length, output_hidden_size};
|
||||
|
||||
if (use_float16) {
|
||||
tester.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
|
||||
|
|
@ -122,6 +137,14 @@ static void RunAttentionTest(
|
|||
}
|
||||
}
|
||||
|
||||
if (extra_add_data.size() > 0) {
|
||||
if (!use_past_state) {
|
||||
tester.AddOptionalInputEdge<float>();
|
||||
}
|
||||
std::vector<int64_t> extra_add_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length};
|
||||
tester.AddInput<float>("extra_add_qk", extra_add_data_dims, extra_add_data);
|
||||
}
|
||||
|
||||
if (enable_cuda) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
|
|
@ -155,17 +178,20 @@ static void RunAttentionTest(
|
|||
MaskIndexType mask_index_type = kMaskIndexEnd,
|
||||
int input_hidden_size = 0,
|
||||
int max_sequence_length = 0,
|
||||
bool only_enable_cuda = false) {
|
||||
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
|
||||
only_enable_cuda);
|
||||
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
|
||||
only_enable_cuda);
|
||||
bool only_enable_cuda = false,
|
||||
bool only_enable_cpu = false,
|
||||
const std::vector<int32_t> qkv_sizes = {},
|
||||
const std::vector<float>& extra_add_data = {}) {
|
||||
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
|
||||
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
|
||||
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
|
||||
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1) {
|
||||
|
|
@ -197,6 +223,163 @@ TEST(AttentionTest, AttentionBatch1) {
|
|||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1WithQKVAttr1) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<int32_t> qkv_sizes = {
|
||||
6, 6, 4};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f,
|
||||
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 2.4f, 3.3f, 2.1f, 4.2f, 0.5f, 0.1f, 0.4f, 1.6f,
|
||||
0.4f, 0.8f, 0.9f, 0.1f
|
||||
};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
|
||||
0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f,
|
||||
0.5f, 0.7f, 0.2f, 1.2f};
|
||||
|
||||
std::vector<int32_t> mask_index_data = {2L};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.1967618465423584f, 0.51903456449508667f, 0.63051539659500122f, 2.9394614696502686f,
|
||||
0.65332180261611938f, 1.000949501991272f, 0.74175024032592773f, 2.8231701850891113f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1WithQKVAttr2) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
-0.031707365f, 0.053643607f, 0.057394292f, -0.019800574f, 0.075466447f, -0.0034214978f, 0.012995008f, -0.019587509f};
|
||||
|
||||
std::vector<int32_t> qkv_sizes = {
|
||||
6, 6, 2};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f,
|
||||
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 2.4f, 3.3f, 2.1f, 4.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
|
||||
0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f,
|
||||
0.5f, 0.7f};
|
||||
|
||||
std::vector<int32_t> mask_index_data = {2L};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
0.64932525157928467f, 0.79390722513198853f, 0.64932847023010254f, 0.79375863075256348f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1ExtraAdd) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<int32_t> qkv_sizes = {};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
|
||||
0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
std::vector<int32_t> mask_index_data = {2L};
|
||||
|
||||
std::vector<float> extra_add_qk = {
|
||||
0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f,
|
||||
-1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes, extra_add_qk);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch2ExtraAdd) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f,
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<int32_t> qkv_sizes = {};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
|
||||
0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
std::vector<int32_t> mask_index_data = {2L, 2L};
|
||||
|
||||
std::vector<float> extra_add_qk = {
|
||||
0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f,
|
||||
0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f,
|
||||
-1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f,
|
||||
4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f,
|
||||
-1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes, extra_add_qk);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1_Float16) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import onnx
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from packaging import version
|
||||
from onnx import helper, TensorProto
|
||||
|
|
@ -17,7 +18,7 @@ def float_tensor(name: str, shape: List[int], random=False):
|
|||
total_elements = 1
|
||||
for x in shape:
|
||||
total_elements *= x
|
||||
weights = [random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements
|
||||
weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements
|
||||
return helper.make_tensor(name, TensorProto.FLOAT, shape, weights)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue