From afce0e2543f04f024e761946dd9abbde93bc6d89 Mon Sep 17 00:00:00 2001 From: Viswanath Boga <44417868+viboga@users.noreply.github.com> Date: Mon, 19 Jul 2021 12:21:33 -0700 Subject: [PATCH] Attention kernel update to handle different Q,K,V hidden sizes (#8039) * changes working to convert akv nodes * changes to replace nodes * changes to accommodate qkv hidden sizes as attributes * kernel to accept qkv_hidden_size attributes * Working till compute for varied dimension, todo applyattention() * changes to make all regression tests work * inference running successfully without prepack * success inference with pre-pack weights * add test for diff sizes * bias shape need not be a mul of 3 * get the output_hidden_size from input * infer output shape from input * merge with master * cleaning up files that got merged wrong * accuracy at accepted level * added unit test case for different dimensions * all unit tests passing * packed weights working for attention * prepacked weights working * added test case for newly added extra qk input * updated unit test to test only extra add qk * fixing build error * removing few debugs * reverting test changes * all python test passing * cleaning up * new unit test added, major clean up of code * removed extra code * minor * minor fix to tests * prepack weights code cleaned up * compacted compute() in attention.cc * reformat compute() * making a parameter T * adding 3 q,k,v buffers in all cases * fixing build * running tests only on cpu * Updating docs * trigger ci builds * Addressing comments in PR * addressing some more comments * get add_qk_str from add_qk node directly * updating docs, added extra check to verify attn inputs * Optimized the extra add by parallelizing * added attention_shape to symbolic_shape_infer.py * minor refactoring to address comments --- docs/ContribOperators.md | 6 +- docs/OperatorKernels.md | 4 +- onnxruntime/contrib_ops/cpu/bert/attention.cc | 353 +++++++++++++----- 
.../contrib_ops/cpu/bert/attention_base.h | 8 +- .../contrib_ops/cpu/bert/attention_cpu_base.h | 31 +- .../cpu/quantization/attention_quant.cc | 5 +- .../core/graph/contrib_ops/contrib_defs.cc | 25 +- .../python/tools/symbolic_shape_infer.py | 7 +- .../tools/transformers/fusion_attention.py | 40 +- .../transformers/onnx_model_bert_keras.py | 2 +- .../tools/transformers/onnx_model_bert_tf.py | 2 +- .../test/contrib_ops/attention_op_test.cc | 215 ++++++++++- .../transformers/bert_model_generator.py | 3 +- 13 files changed, 577 insertions(+), 124 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 03b1b24c49..60fbad8254 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -84,11 +84,13 @@ This version of the operator has been available since version 1 of the 'com.micr
num_heads : int (required)
Number of attention heads
+
qkv_hidden_sizes : list of ints
+
Hidden layer sizes of Q, K, V paths in Attention
unidirectional : int
Whether every token can only attend to previous tokens. Default value is 0.
-#### Inputs (3 - 5) +#### Inputs (3 - 6)
input : T
@@ -101,6 +103,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
+
extra_add (optional) : T
+
additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
#### Outputs (1 - 2) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3665bae302..0286bf3cd7 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -366,7 +366,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| +|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)| |AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)| @@ -705,7 +705,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index a1d6889c6c..8375ff7dfb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -20,6 +20,8 @@ class Attention : public OpKernel, public AttentionCPUBase { public: explicit Attention(const OpKernelInfo& info); + bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size, size_t input_hidden_size, const T* weights_data, size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights); + Status Compute(OpKernelContext* context) const override; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, @@ -31,8 +33,13 @@ class Attention : public OpKernel, public AttentionCPUBase { /*out*/ bool& used_shared_buffers) override; private: - BufferUniquePtr packed_weights_; - size_t packed_weights_size_ = 0; + BufferUniquePtr q_packed_weights_; + BufferUniquePtr k_packed_weights_; + BufferUniquePtr v_packed_weights_; + + size_t q_packed_weights_size_ = 0; + size_t k_packed_weights_size_ = 0; + size_t v_packed_weights_size_ = 0; TensorShape weight_shape_; }; @@ -51,7 +58,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const Tensor*& mask_index, - const Tensor* past) const { + const Tensor* past, + const Tensor* extra_add_qk) const { // Input shapes: // input : (batch_size, sequence_length, input_hidden_size) // weights : (input_hidden_size, 3 * hidden_size) @@ -61,10 +69,16 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // or (batch_size, past_sequence_length + sequence_length) // or (batch_size, sequence_length, past_sequence_length + sequence_length) // past : (2, batch_size, num_heads, past_sequence_length, head_size) + // extra_add_qk: (batch_size, num_heads, sequence_length, 
sequence_length) // // Where hidden_size = num_heads * head_size. // When a model is pruned (like some attention heads are removed), hidden_size < input_hidden_size. + if (past != nullptr && extra_add_qk != nullptr) { + // past is used on GPT-2 model with past state, we don't have a case for extra add qk yet + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have past sequence and extra add qk"); + } + const auto& dims = input_shape.GetDims(); if (dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ", @@ -83,21 +97,53 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, "Input 1 dimension 0 should have same length as dimension 2 of input 0"); } - int hidden_size = static_cast(weights_dims[1]) / 3; - if (3 * hidden_size != static_cast(weights_dims[1])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 1 dimension 1 should be 3 times of hidden dimension"); - } - - if (hidden_size % num_heads_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads."); - } - const auto& bias_dims = bias_shape.GetDims(); if (bias_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ", bias_dims.size()); } + + int hidden_size = 0; + + if (qkv_hidden_sizes_.size() == 0) { + hidden_size = static_cast(weights_dims[1]) / 3; + if (3 * hidden_size != static_cast(weights_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 dimension 1 should be 3 times of hidden dimension"); + } + + if (hidden_size % num_heads_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads."); + } + } else { + int qkv_sizes = 0; + + if (qkv_hidden_sizes_.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qkv_hidden_sizes attribute should have 3 elements"); + } + 
+ if (qkv_hidden_sizes_[0] != qkv_hidden_sizes_[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qkv_hidden_sizes first element should be same as the second"); + } + + for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) { + if (qkv_hidden_sizes_[i] % num_heads_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads:", qkv_hidden_sizes_[i]); + } + + qkv_sizes += static_cast(qkv_hidden_sizes_[i]); + } + + int qkv_hidden_sizes_sum = static_cast(weights_dims[1]); + if (qkv_hidden_sizes_sum != qkv_sizes) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qkv_sizes doesn't match the wights dimension"); + } + + hidden_size = static_cast(qkv_hidden_sizes_[2]); + } + if (bias_dims[0] != weights_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'"); @@ -157,6 +203,33 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, mask_dims.size()); } } + + if (extra_add_qk != nullptr) { + const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims(); + + if (extra_add_qk_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' is expected to have 4 dimensions, got ", + extra_add_qk_dims.size()); + } + + if (extra_add_qk_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ", + extra_add_qk_dims[0]); + } + if (extra_add_qk_dims[1] != num_heads_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ", + extra_add_qk_dims[1]); + } + if (extra_add_qk_dims[2] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ", + extra_add_qk_dims[2]); + } + if (extra_add_qk_dims[3] != sequence_length) { 
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ", + extra_add_qk_dims[3]); + } + } + return Status::OK(); } @@ -170,7 +243,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past); + return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, nullptr); } Tensor* AttentionBase::GetPresent(OpKernelContext* context, @@ -203,6 +276,71 @@ template Attention::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info) { } +template +bool Attention::IsPackWeightsSuccessful(int qkv_index, + AllocatorPtr alloc, + size_t head_size, + size_t input_hidden_size, + const T* weights_data, + size_t weight_matrix_col_size, + /*out*/ PrePackedWeights* prepacked_weights) { + size_t packb_size = MlasGemmPackBSize(head_size, input_hidden_size); + if (packb_size == 0) { + return false; + } + + size_t loop_len = static_cast(num_heads_); + size_t packed_weights_data_size = packb_size * loop_len; // The same size would be computed by AllocArray() below + auto* packed_weights_data = static_cast(alloc->AllocArray(packb_size, loop_len)); + + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we don not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. 
+ memset(packed_weights_data, 0, packed_weights_data_size); + switch (qkv_index) { + case 0: + q_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); + q_packed_weights_size_ = packb_size; + break; + case 1: + k_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); + k_packed_weights_size_ = packb_size; + break; + case 2: + v_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); + v_packed_weights_size_ = packb_size; + break; + default: + return false; + } + + for (size_t i = 0; i < loop_len; i++) { + MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); + packed_weights_data += packb_size; + weights_data += head_size; + } + + bool share_prepacked_weights = (prepacked_weights != nullptr); + if (share_prepacked_weights) { + switch (qkv_index) { + case 0: + prepacked_weights->buffers_.push_back(std::move(q_packed_weights_)); + break; + case 1: + prepacked_weights->buffers_.push_back(std::move(k_packed_weights_)); + break; + case 2: + prepacked_weights->buffers_.push_back(std::move(v_packed_weights_)); + break; + default: + break; + } + + prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size); + } + return true; +} + template Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -219,46 +357,47 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr return Status::OK(); } - const size_t input_hidden_size = static_cast(weights_dims[0]); - const size_t hidden_size_x3 = static_cast(weights_dims[1]); - const size_t hidden_size = hidden_size_x3 / 3; - const size_t head_size = hidden_size / num_heads_; - - // Bail out if the weights shape has an expected shape. 
- if ((hidden_size == 0) || ((hidden_size % num_heads_) != 0) || (hidden_size_x3 != 3 * hidden_size)) { - return Status::OK(); - } - const auto* weights_data = weights.Data(); + const size_t input_hidden_size = static_cast(weights_dims[0]); + size_t q_hidden_size, k_hidden_size, v_hidden_size; - packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size); - if (packed_weights_size_ == 0) { + if (qkv_hidden_sizes_.size() != 0) { + q_hidden_size = qkv_hidden_sizes_[0]; + k_hidden_size = qkv_hidden_sizes_[1]; + v_hidden_size = qkv_hidden_sizes_[2]; + + if (q_hidden_size == 0 || k_hidden_size == 0 || v_hidden_size == 0) { + return Status::OK(); + } + + if (q_hidden_size % num_heads_ != 0 || k_hidden_size % num_heads_ != 0 || v_hidden_size % num_heads_ != 0) { + return Status::OK(); + } + } else { + const size_t hidden_size_x3 = static_cast(weights_dims[1]); + const size_t hidden_size = hidden_size_x3 / 3; + + if (hidden_size % num_heads_ != 0) { + return Status::OK(); + } + + q_hidden_size = hidden_size; + k_hidden_size = hidden_size; + v_hidden_size = hidden_size; + } + + const size_t q_head_size = q_hidden_size / num_heads_; + const size_t k_head_size = k_hidden_size / num_heads_; + const size_t v_head_size = v_hidden_size / num_heads_; + const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size; + + if (!IsPackWeightsSuccessful(0, alloc, q_head_size, input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) || + !IsPackWeightsSuccessful(1, alloc, k_head_size, input_hidden_size, weights_data + (num_heads_ * q_head_size), weight_matrix_col_size, prepacked_weights) || + !IsPackWeightsSuccessful(2, alloc, v_head_size, input_hidden_size, weights_data + (num_heads_ * (q_head_size + k_head_size)), weight_matrix_col_size, prepacked_weights)) { + // we are not cleaning up anything, assuming caller takes care of this return Status::OK(); } - const size_t loop_len = static_cast(3) * num_heads_; - size_t 
packed_weights_data_size = packed_weights_size_ * loop_len; // The same size would be computed by AllocArray() below - auto* packed_weights_data = static_cast(alloc->AllocArray(packed_weights_size_, loop_len)); - - // Initialize memory to 0 as there could be some padding associated with pre-packed - // buffer memory and we don not want it uninitialized and generate different hashes - // if and when we try to cache this pre-packed buffer for sharing between sessions. - memset(packed_weights_data, 0, packed_weights_data_size); - - packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); - - for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, hidden_size_x3, packed_weights_data); - packed_weights_data += packed_weights_size_; - weights_data += head_size; - } - - bool share_prepacked_weights = (prepacked_weights != nullptr); - if (share_prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_weights_)); - prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size); - } - is_packed = true; return Status::OK(); } @@ -272,7 +411,9 @@ Status Attention::UseSharedPrePackedBuffers(std::vector& pre } used_shared_buffers = true; - packed_weights_ = std::move(prepacked_buffers[0]); + q_packed_weights_ = std::move(prepacked_buffers[0]); + k_packed_weights_ = std::move(prepacked_buffers[1]); + v_packed_weights_ = std::move(prepacked_buffers[2]); return Status::OK(); } @@ -280,25 +421,35 @@ Status Attention::UseSharedPrePackedBuffers(std::vector& pre template Status Attention::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); - const Tensor* weights = packed_weights_ ? nullptr : context->Input(1); + const Tensor* weights = q_packed_weights_ ? 
nullptr : context->Input(1); const Tensor* bias = context->Input(2); + const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); + const Tensor* extra_add_qk = context->Input(5); const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights_shape, bias->Shape(), mask_index, - past)); + past, + extra_add_qk)); const auto& shape = input->Shape().GetDims(); const int batch_size = static_cast(shape[0]); const int sequence_length = static_cast(shape[1]); const int input_hidden_size = static_cast(shape[2]); + + int hidden_size; + + if (qkv_hidden_sizes_.size() == 0) { + const auto& weights_dims = weights_shape.GetDims(); + hidden_size = static_cast(weights_dims[1]) / 3; + } else { + hidden_size = static_cast(qkv_hidden_sizes_[2]); + } - const auto& weights_dims = weights_shape.GetDims(); - const int hidden_size = static_cast(weights_dims[1]) / 3; const int head_size = hidden_size / num_heads_; std::vector output_shape(3); @@ -309,19 +460,35 @@ Status Attention::Compute(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); + int q_hidden_size = 0; + int k_hidden_size = 0; + int v_hidden_size = 0; + if (qkv_hidden_sizes_.size() == 0) { + q_hidden_size = hidden_size; + k_hidden_size = hidden_size; + v_hidden_size = hidden_size; + } else { + q_hidden_size = static_cast(qkv_hidden_sizes_[0]); + k_hidden_size = static_cast(qkv_hidden_sizes_[1]); + v_hidden_size = static_cast(qkv_hidden_sizes_[2]); + } + const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_}; + AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); // Compute Q, K, V - // gemm_data(BS, 3NH) = input(BS, D) x weights(D, 3NH) + bias(3NH) - // D (input_hidden_size) is hidden dimension of input, where D could be larger than hidden_size (NH) when model is 
pruned. - auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); + // gemm_data(BS, NT) = input(BS, D) x weights(D, NT) + bias(NT) + // D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes + // (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively + auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator)); auto Q = reinterpret_cast(gemm_data); - auto K = Q + static_cast(batch_size) * sequence_length * hidden_size; - auto V = K + static_cast(batch_size) * sequence_length * hidden_size; + auto K = Q + static_cast(batch_size) * sequence_length * q_hidden_size; + auto V = K + static_cast(batch_size) * sequence_length * k_hidden_size; + T* QKV[3] = {Q, K, V}; { @@ -339,14 +506,25 @@ Status Attention::Compute(OpKernelContext* context) const { const int qkv_index = static_cast(i % 3); int input_offset = batch_index * sequence_length * input_hidden_size; - int weights_offset = qkv_index * hidden_size + head_index * head_size; + T* qkv_dest = QKV[qkv_index]; + int head_size = qkv_head_size[qkv_index]; + int weights_offset = 0; + int bias_offset = qkv_index * q_hidden_size + head_index * head_size; + + if (q_packed_weights_ == nullptr) { + weights_offset = bias_offset; + } else { + weights_offset = head_index * head_size; + } + int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); // TODO!! memcpy here makes it not worthwhile to use Gemm batch. Possible to post process? 
- // broadcast 3NH -> (3.B.N.S.H) - const T* broadcast_data_src = bias_data + weights_offset; + // broadcast NH -> (B.N.S.H) for each of Q, K, V + const T* broadcast_data_src = bias_data + bias_offset; T* broadcast_data_dest = QKV[qkv_index] + qkv_offset; + for (int seq_index = 0; seq_index < sequence_length; seq_index++) { memcpy(broadcast_data_dest, broadcast_data_src, head_size * sizeof(T)); broadcast_data_dest += head_size; @@ -354,15 +532,22 @@ Status Attention::Compute(OpKernelContext* context) const { // original transposed iteration // A: input (BxSxD) (B.)S x D S x D - // B: weights (Dx3xNxH) D x (3.N.)H D x H - // C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H - if (packed_weights_) { - const auto* packed_weight = - static_cast(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size); + // B: weights (DxNxT) D x (N.)T D x H + // C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H + if (q_packed_weights_) { + uint8_t* packed_weight; + if (qkv_index == 0) { + packed_weight = static_cast(q_packed_weights_.get()) + q_packed_weights_size_ * (weights_offset / head_size); + } else if (qkv_index == 1) { + packed_weight = static_cast(k_packed_weights_.get()) + k_packed_weights_size_ * (weights_offset / head_size); + } else { + packed_weight = static_cast(v_packed_weights_.get()) + v_packed_weights_size_ * (weights_offset / head_size); + } + MlasGemm( CblasNoTrans, // TransA = no sequence_length, // M = S - head_size, // N = H + head_size, // N = H input_hidden_size, // K = D 1.0f, // alpha input_data + input_offset, // A @@ -374,20 +559,20 @@ Status Attention::Compute(OpKernelContext* context) const { nullptr); // use single-thread } else { math::GemmEx( - CblasNoTrans, // TransA = no - CblasNoTrans, // TransB = no - sequence_length, // M = S - head_size, // N = H - input_hidden_size, // K = D - 1.0f, // alpha - input_data + input_offset, // A - input_hidden_size, // lda = D - weights_data + weights_offset, // B - 3 * hidden_size, // ldb = 3NH - 
1.0f, // beta - qkv_dest + qkv_offset, // C - head_size, // ldc - nullptr // use single-thread + CblasNoTrans, // TransA = no + CblasNoTrans, // TransB = no + sequence_length, // M = S + head_size, // N = H + input_hidden_size, // K = D + 1.0f, // alpha + input_data + input_offset, // A + input_hidden_size, // lda = D + weights_data + weights_offset, // B + q_hidden_size + k_hidden_size + v_hidden_size,// ldb = NH1 + NH2 + NH3 + 1.0f, // beta + qkv_dest + qkv_offset, // C + head_size, // ldc + nullptr // use single-thread ); } } @@ -397,8 +582,8 @@ Status Attention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V return ApplyAttention(Q, K, V, mask_index, past, output, batch_size, sequence_length, - head_size, hidden_size, context); + qkv_head_size[0], qkv_head_size[2], v_hidden_size, + extra_add_qk, context); } - } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index f14268af3e..ce9033afd9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -33,16 +33,22 @@ class AttentionBase { num_heads_ = static_cast(num_heads); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + + if (!info.GetAttrs("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK() || qkv_hidden_sizes_.empty()) { + qkv_hidden_sizes_.resize(0); + } } Status CheckInputs(const TensorShape& input_shape, const TensorShape& weights_shape, const TensorShape& bias_shape, const Tensor*& mask_index, // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr. - const Tensor* past) const; + const Tensor* past, + const Tensor *extra_add_qk) const; int num_heads_; // number of attention heads bool is_unidirectional_; // whether every token can only attend to previous tokens. 
+ std::vector qkv_hidden_sizes_; // Q, K, V path hidden layer sizes }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 063093d16c..46f2b292d5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -26,8 +26,10 @@ class AttentionCPUBase : public AttentionBase { Tensor* output, // output tensor int batch_size, // batch size int sequence_length, // sequence length - int head_size, // head size - int hidden_size, // hidden size + int qk_head_size, // qk_head_size + int v_head_size, // head_size + int v_hidden_size, // hidden_size + const Tensor* extra_add_qk,// extra add in QK. Its size is BxNxSxS OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -35,7 +37,7 @@ class AttentionCPUBase : public AttentionBase { auto* tp = context->GetOperatorThreadPool(); int past_sequence_length = 0; - Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); + Tensor* present = GetPresent(context, past, batch_size, v_head_size, sequence_length, past_sequence_length); // Total sequence length including that of past state: S* = S' + S const int all_sequence_length = past_sequence_length + sequence_length; @@ -61,18 +63,23 @@ class AttentionCPUBase : public AttentionBase { const T* past_data = past != nullptr ? past->template Data() : nullptr; T* present_data = present != nullptr ? 
present->template MutableData() : nullptr; + const T* extra_add_qk_data = nullptr; + if (extra_add_qk != nullptr) { + extra_add_qk_data = extra_add_qk->template Data(); + } + ComputeAttentionProbs(static_cast(attention_probs), Q, K, mask_index_data, mask_index_dims, static_cast(mask_data), - batch_size, sequence_length, past_sequence_length, head_size, - past_data, present_data, tp); + batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, + past_data, present_data, tp, extra_add_qk_data); // Compute the attentionScore * Value. It does: out_tmp(B, N, S, H) = attention_probs(B, N, S, S*) x V(B, N, S*, H) auto out_tmp_data = - allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T)); + allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T)); BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator)); ComputeVxAttentionScore(output->template MutableData(), static_cast(out_tmp_data), static_cast(attention_probs), V, - batch_size, sequence_length, past_sequence_length, head_size, hidden_size, + batch_size, sequence_length, past_sequence_length, v_head_size, v_hidden_size, past_data, present_data, tp); return Status::OK(); @@ -96,7 +103,8 @@ class AttentionCPUBase : public AttentionBase { int head_size, // head size of self-attention const T* past, // past state T* present, // present state - ThreadPool* tp) const { + ThreadPool* tp, + const T* extra_add_qk_data) const { const int all_sequence_length = past_sequence_length + sequence_length; // S* = S' + S const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // S' x H const size_t input_chunk_length = static_cast(sequence_length) * head_size; // S x H @@ -140,6 +148,13 @@ class AttentionCPUBase : public AttentionBase { math::Gemm(CblasNoTrans, CblasTrans, sequence_length, all_sequence_length, head_size, alpha, Q + input_chunk_length * i, k, 1.0, 
reinterpret_cast(attention_probs) + sequence_length * all_sequence_length * i, nullptr); + + if (extra_add_qk_data != nullptr) { + int extra_add_qk_offset = static_cast(i) * sequence_length * all_sequence_length; + for (int j = 0; j < sequence_length * all_sequence_length ; j++) { + attention_probs[extra_add_qk_offset+j] += extra_add_qk_data[extra_add_qk_offset + j]; + } + } } }); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index c8f9a40c23..9f7168b81f 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -158,7 +158,8 @@ Status QAttention::Compute(OpKernelContext* context) const { weights_shape, bias->Shape(), mask_index, - past_tensor)); + past_tensor, + nullptr)); ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor), "input scale must be a scalar or 1D tensor of size 1"); @@ -286,7 +287,7 @@ Status QAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V return ApplyAttention(Q, K, V, mask_index, past_tensor, output, batch_size, sequence_length, - head_size, hidden_size, context); + head_size, head_size, hidden_size, nullptr, context); } } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index abad4abb75..d0050ac069 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -397,17 +397,32 @@ void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int p auto& bias_shape = getInputShape(ctx, 2); auto& bias_dims = bias_shape.dim(); - if (bias_dims.size() != 1 || bias_shape.dim(0).dim_value() % 3 != 0) { + if (bias_dims.size() != 1) { fail_shape_inference("Invalid bias shape"); } + std::vector qkv_hidden_sizes; + getRepeatedAttribute(ctx, 
"qkv_hidden_sizes", qkv_hidden_sizes); + + int64_t output_hidden_size; + if (qkv_hidden_sizes.size() != 0) { + if (qkv_hidden_sizes.size() != 3) { + fail_shape_inference("qkv_hidden_sizes should have 3 elements") + } + output_hidden_size = qkv_hidden_sizes[2]; + } else { + output_hidden_size = bias_shape.dim(0).dim_value() / 3; + } + ONNX_NAMESPACE::TensorShapeProto output_shape; for (auto& dim : input_dims) { *output_shape.add_dim() = dim; } - output_shape.mutable_dim(2)->set_dim_value(bias_shape.dim(0).dim_value() / 3); + + output_shape.mutable_dim(2)->set_dim_value(output_hidden_size); updateOutputShape(ctx, 0, output_shape); + // TODO does the extra output need any changes? if (ctx.getNumOutputs() > 1) { if (hasInputShape(ctx, past_input_index)) { auto& past_shape = getInputShape(ctx, past_input_index); @@ -453,12 +468,17 @@ and present state are optional. Present state could appear in output even when p "Whether every token can only attend to previous tokens. Default value is 0.", AttributeProto::INT, static_cast(0)) + .Attr("qkv_hidden_sizes", + "Hidden layer sizes of Q, K, V paths in Attention", + AttributeProto::INTS, + OPTIONAL_VALUE) .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T") .Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T") .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T") .Input(3, "mask_index", "Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)" "or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional) .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional) + .Input(5, "extra_add", "additional add to QxK' with shape 
(batch_size, num_heads, sequence_length, sequence_length).", "T", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T") .Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") @@ -545,6 +565,7 @@ and present state are optional. Present state could appear in output even when p .TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { constexpr int past_input_index = 8; + AttentionTypeAndShapeInference(ctx, past_input_index); }); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 2c50b4a6b0..fdccf86fd4 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1479,7 +1479,12 @@ class SymbolicShapeInference: shape = self._get_shape(node, 0) shape_bias = self._get_shape(node, 2) assert len(shape) == 3 and len(shape_bias) == 1 - shape[2] = int(shape_bias[0] / 3) + qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes') + if qkv_hidden_sizes_attr is not None: + assert len(qkv_hidden_sizes_attr) == 3 + shape[2] = int(qkv_hidden_sizes_attr[2]) + else: + shape[2] = int(shape_bias[0] / 3) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index c18225d6e9..7830971c4e 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ 
b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -12,6 +12,7 @@ from onnx import helper, numpy_helper, TensorProto, NodeProto from onnx_model import OnnxModel from fusion_base import Fusion from fusion_utils import FusionUtils, NumpyHelper +from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto logger = getLogger(__name__) @@ -126,9 +127,29 @@ class FusionAttention(Fusion): return num_heads, hidden_size + def get_add_qk_str(self, add_qk: NodeProto): + # Note: Does not work for dynamic models, reshape node shape inference would fail with more than 2 dims being -1 + # inputs_ids has to be the name of input, this may be changes if needed + inputs_ids = self.model.find_graph_input("input_ids") + if inputs_ids == None: + logger.debug("no input with name \"input_ids\"") + return None + batch_size, seq_len = get_shape_from_type_proto(inputs_ids.type) + + if batch_size < 0 or seq_len < 0: + logger.debug(f"batch_size: {batch_size} and seq_len {seq_len} cannot be -ve") + return None + shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model) + shape_infer_helper.infer({"batch_size": batch_size, "seq_len": seq_len}) + if shape_infer_helper.get_edge_shape(add_qk.input[0]) != shape_infer_helper.get_edge_shape(add_qk.input[1]): + logger.debug(f"the shape of two inputs of {add_qk} is not same") + return None + + return add_qk.input[1] + def create_attention_node(self, mask_index: str, q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto, q_add: NodeProto, k_add: NodeProto, v_add: NodeProto, num_heads: int, hidden_size: int, - input: str, output: str) -> Union[NodeProto, None]: + input: str, output: str, add_qk_str: str) -> Union[NodeProto, None]: """ Create an Attention node. 
Args: @@ -165,6 +186,7 @@ class FusionAttention(Fusion): return None if not (k_weight and v_weight and q_bias and k_bias): return None + qw = NumpyHelper.to_array(q_weight) kw = NumpyHelper.to_array(k_weight) vw = NumpyHelper.to_array(v_weight) @@ -246,6 +268,10 @@ class FusionAttention(Fusion): if mask_index is not None: attention_inputs.append(mask_index) + if add_qk_str is not None: + attention_inputs.append("") + attention_inputs.append(add_qk_str) + attention_node = helper.make_node('Attention', inputs=attention_inputs, outputs=[output], @@ -334,7 +360,7 @@ class FusionAttention(Fusion): logger.debug("fuse_attention: failed to match v path") return (_, _, add_v, matmul_v) = v_nodes - + is_distill = False is_distill_add = False qk_paths = { @@ -365,7 +391,7 @@ class FusionAttention(Fusion): if is_distill: (_, where_qk, matmul_qk, _) = qk_nodes elif is_distill_add: - (_, _, where_qk, matmul_qk) = qk_nodes + (_, add_qk, where_qk, matmul_qk) = qk_nodes else: (_, add_qk, _, matmul_qk) = qk_nodes @@ -392,6 +418,7 @@ class FusionAttention(Fusion): # Note that Cast might be removed by OnnxRuntime so we match two patterns here. 
mask_nodes = None + add_qk_str = None if is_distill: _, mask_nodes, _ = self.model.match_parent_paths(where_qk, [(['Expand', 'Reshape', 'Equal'], [0, 0, 0]), @@ -401,6 +428,11 @@ class FusionAttention(Fusion): _, mask_nodes, _ = self.model.match_parent_paths( where_qk, [(['Cast', 'Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0, 0]), (['Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0])], output_name_to_node) + if add_qk is not None: + add_qk_str = self.get_add_qk_str(add_qk) + if add_qk_str is None: + logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}") + return else: _, mask_nodes, _ = self.model.match_parent_paths( add_qk, [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0, 0]), @@ -419,7 +451,7 @@ class FusionAttention(Fusion): # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately new_node = self.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, q_num_heads, self.hidden_size, root_input, - attention_last_node.output[0]) + attention_last_node.output[0], add_qk_str) if new_node is None: return diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py index 52c3f1017d..c0f8f35517 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py @@ -140,7 +140,7 @@ class BertOnnxModelKeras(BertOnnxModelTF): attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, self.num_heads, self.hidden_size, parent.output[0], - reshape_qkv.output[0]) + reshape_qkv.output[0], None) if attention_node is None: continue diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py index 28bff0cbf5..e05593cfff 100644 --- 
a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py @@ -384,7 +384,7 @@ class BertOnnxModelTF(BertOnnxModel): attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_k, matmul_q, matmul_v, add_k, add_q, add_v, self.num_heads, self.hidden_size, parent.output[0], - qkv_nodes[2].output[0]) + qkv_nodes[2].output[0], None) if attention_node is None: continue diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 846699732c..9cd2863de4 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -38,11 +38,14 @@ static void RunAttentionTest( MaskIndexType mask_index_type = kMaskIndexEnd, int input_hidden_size = 0, int max_sequence_length = 0, - bool only_enable_cuda = false) { + bool only_enable_cuda = false, + bool only_enable_cpu = false, + std::vector qkv_sizes = {}, + const std::vector& extra_add_data = {}) { input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning. int min_cuda_architecture = use_float16 ? 530 : 0; - bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda; int head_size = hidden_size / number_of_heads; @@ -51,9 +54,21 @@ static void RunAttentionTest( tester.AddAttribute("num_heads", static_cast(number_of_heads)); tester.AddAttribute("unidirectional", static_cast(is_unidirectional ? 
1 : 0)); + int32_t matrix_size; + int32_t output_hidden_size; + if (qkv_sizes.size() != 0) { + matrix_size = qkv_sizes[0] + qkv_sizes[1] + qkv_sizes[2]; + std::vector sizes_attribute{qkv_sizes[0], qkv_sizes[1], qkv_sizes[2]}; + tester.AddAttribute>("qkv_hidden_sizes", sizes_attribute); + output_hidden_size = qkv_sizes[2]; + } else { + matrix_size = 3 * hidden_size; + output_hidden_size = hidden_size; + } + std::vector input_dims = {batch_size, sequence_length, input_hidden_size}; - std::vector weights_dims = {input_hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; + std::vector weights_dims = {input_hidden_size, matrix_size}; + std::vector bias_dims = {matrix_size}; std::vector mask_index_dims_1 = {batch_size}; std::vector mask_index_dims_2 = {2 * batch_size}; @@ -88,7 +103,7 @@ static void RunAttentionTest( std::vector past_dims = {2, batch_size, number_of_heads, past_sequence_length, head_size}; std::vector present_dims = {2, batch_size, number_of_heads, past_sequence_length + sequence_length, head_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; + std::vector output_dims = {batch_size, sequence_length, output_hidden_size}; if (use_float16) { tester.AddInput("input", input_dims, ToFloat16(input_data)); @@ -122,6 +137,14 @@ static void RunAttentionTest( } } + if (extra_add_data.size() > 0) { + if (!use_past_state) { + tester.AddOptionalInputEdge(); + } + std::vector extra_add_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length}; + tester.AddInput("extra_add_qk", extra_add_data_dims, extra_add_data); + } + if (enable_cuda) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); @@ -155,17 +178,20 @@ static void RunAttentionTest( MaskIndexType mask_index_type = kMaskIndexEnd, int input_hidden_size = 0, int max_sequence_length = 0, - bool only_enable_cuda = false) { - RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, 
output_data, - batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, use_past_state, past_sequence_length, - past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length, - only_enable_cuda); - RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, - batch_size, sequence_length, hidden_size, number_of_heads, - use_float16, is_unidirectional, use_past_state, past_sequence_length, - past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length, - only_enable_cuda); + bool only_enable_cuda = false, + bool only_enable_cpu = false, + const std::vector qkv_sizes = {}, + const std::vector& extra_add_data = {}) { + RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, + past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length, + only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data); + RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, + past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length, + only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data); } TEST(AttentionTest, AttentionBatch1) { @@ -197,6 +223,163 @@ TEST(AttentionTest, AttentionBatch1) { batch_size, sequence_length, hidden_size, number_of_heads); } +TEST(AttentionTest, AttentionBatch1WithQKVAttr1) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector qkv_sizes = { + 6, 6, 4}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 
0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f, + + 0.3f, 0.2f, 4.0f, 2.2f, 2.4f, 3.3f, 2.1f, 4.2f, 0.5f, 0.1f, 0.4f, 1.6f, + 0.4f, 0.8f, 0.9f, 0.1f + }; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, + 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f, + 0.5f, 0.7f, 0.2f, 1.2f}; + + std::vector mask_index_data = {2L}; + + std::vector output_data = { + 3.1967618465423584f, 0.51903456449508667f, 0.63051539659500122f, 2.9394614696502686f, + 0.65332180261611938f, 1.000949501991272f, 0.74175024032592773f, 2.8231701850891113f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, + 0, false, true, qkv_sizes); +} + +TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + -0.031707365f, 0.053643607f, 0.057394292f, -0.019800574f, 0.075466447f, -0.0034214978f, 0.012995008f, -0.019587509f}; + + std::vector qkv_sizes = { + 6, 6, 2}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f, + + 0.3f, 0.2f, 4.0f, 2.2f, 2.4f, 3.3f, 2.1f, 4.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, + 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f, + 0.5f, 0.7f}; + + std::vector mask_index_data = {2L}; + + std::vector output_data = { + 0.64932525157928467f, 0.79390722513198853f, 0.64932847023010254f, 0.79375863075256348f}; + 
+ RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, + 0, false, true, qkv_sizes); +} + +TEST(AttentionTest, AttentionBatch1ExtraAdd) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector qkv_sizes = {}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, + 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + std::vector mask_index_data = {2L}; + + std::vector extra_add_qk = { + 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; + + std::vector output_data = { + 4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f, + -1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, + 0, false, true, qkv_sizes, extra_add_qk); +} + +TEST(AttentionTest, AttentionBatch2ExtraAdd) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector qkv_sizes = {}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, 
-1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, + 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + std::vector mask_index_data = {2L, 2L}; + + std::vector extra_add_qk = { + 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f, + 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f}; + + std::vector output_data = { + 4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f, + -1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f, + 4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f, + -1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0, + 0, false, true, qkv_sizes, extra_add_qk); +} + TEST(AttentionTest, AttentionBatch1_Float16) { int batch_size = 1; int sequence_length = 2; diff --git a/onnxruntime/test/python/transformers/bert_model_generator.py b/onnxruntime/test/python/transformers/bert_model_generator.py index dc9a504578..4764334ada 100644 --- a/onnxruntime/test/python/transformers/bert_model_generator.py +++ b/onnxruntime/test/python/transformers/bert_model_generator.py @@ -6,6 +6,7 @@ import onnx import math +import numpy as np from typing import List from packaging import version from onnx import helper, TensorProto @@ -17,7 +18,7 @@ def float_tensor(name: str, shape: List[int], random=False): total_elements = 1 for x in shape: total_elements *= x - weights = [random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements return helper.make_tensor(name, TensorProto.FLOAT, 
shape, weights)