diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 03b1b24c49..60fbad8254 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -84,11 +84,13 @@ This version of the operator has been available since version 1 of the 'com.micr
- num_heads : int (required)
- Number of attention heads
+- qkv_hidden_sizes : list of ints
+- Hidden layer sizes of Q, K, V paths in Attention
- unidirectional : int
- Whether every token can only attend to previous tokens. Default value is 0.
-#### Inputs (3 - 5)
+#### Inputs (3 - 6)
- input : T
@@ -101,6 +103,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
- past (optional) : T
- past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
+- extra_add (optional) : T
+- additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
#### Outputs (1 - 2)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 3665bae302..0286bf3cd7 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -366,7 +366,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)|
+|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)|
|CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -705,7 +705,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|BiasSoftmax|*in* data:**T**
*in* bias:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index a1d6889c6c..8375ff7dfb 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -20,6 +20,8 @@ class Attention : public OpKernel, public AttentionCPUBase {
public:
explicit Attention(const OpKernelInfo& info);
+ bool IsPackWeightsSuccessful(int qkv_index, AllocatorPtr alloc, size_t head_size, size_t input_hidden_size, const T* weights_data, size_t weight_matrix_col_size, PrePackedWeights* prepacked_weights);
+
Status Compute(OpKernelContext* context) const override;
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
@@ -31,8 +33,13 @@ class Attention : public OpKernel, public AttentionCPUBase {
/*out*/ bool& used_shared_buffers) override;
private:
- BufferUniquePtr packed_weights_;
- size_t packed_weights_size_ = 0;
+ BufferUniquePtr q_packed_weights_;
+ BufferUniquePtr k_packed_weights_;
+ BufferUniquePtr v_packed_weights_;
+
+ size_t q_packed_weights_size_ = 0;
+ size_t k_packed_weights_size_ = 0;
+ size_t v_packed_weights_size_ = 0;
TensorShape weight_shape_;
};
@@ -51,7 +58,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const TensorShape& weights_shape,
const TensorShape& bias_shape,
const Tensor*& mask_index,
- const Tensor* past) const {
+ const Tensor* past,
+ const Tensor* extra_add_qk) const {
// Input shapes:
// input : (batch_size, sequence_length, input_hidden_size)
// weights : (input_hidden_size, 3 * hidden_size)
@@ -61,10 +69,16 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// or (batch_size, past_sequence_length + sequence_length)
// or (batch_size, sequence_length, past_sequence_length + sequence_length)
// past : (2, batch_size, num_heads, past_sequence_length, head_size)
+ // extra_add_qk: (batch_size, num_heads, sequence_length, sequence_length)
//
// Where hidden_size = num_heads * head_size.
// When a model is pruned (like some attention heads are removed), hidden_size < input_hidden_size.
+ if (past != nullptr && extra_add_qk != nullptr) {
+ // past is used by GPT-2 style models with past state; combining it with extra_add_qk is not supported yet
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have past sequence and extra add qk");
+ }
+
const auto& dims = input_shape.GetDims();
if (dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
@@ -83,21 +97,53 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
"Input 1 dimension 0 should have same length as dimension 2 of input 0");
}
- int hidden_size = static_cast(weights_dims[1]) / 3;
- if (3 * hidden_size != static_cast(weights_dims[1])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 1 dimension 1 should be 3 times of hidden dimension");
- }
-
- if (hidden_size % num_heads_ != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
- }
-
const auto& bias_dims = bias_shape.GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
bias_dims.size());
}
+
+ int hidden_size = 0;
+
+ if (qkv_hidden_sizes_.size() == 0) {
+ hidden_size = static_cast(weights_dims[1]) / 3;
+ if (3 * hidden_size != static_cast(weights_dims[1])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 1 dimension 1 should be 3 times of hidden dimension");
+ }
+
+ if (hidden_size % num_heads_ != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
+ }
+ } else {
+ int qkv_sizes = 0;
+
+ if (qkv_hidden_sizes_.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "qkv_hidden_sizes attribute should have 3 elements");
+ }
+
+ if (qkv_hidden_sizes_[0] != qkv_hidden_sizes_[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "qkv_hidden_sizes first element should be same as the second");
+ }
+
+ for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) {
+ if (qkv_hidden_sizes_[i] % num_heads_ != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads:", qkv_hidden_sizes_[i]);
+ }
+
+ qkv_sizes += static_cast(qkv_hidden_sizes_[i]);
+ }
+
+ int qkv_hidden_sizes_sum = static_cast(weights_dims[1]);
+ if (qkv_hidden_sizes_sum != qkv_sizes) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qkv_sizes doesn't match the wights dimension");
+ }
+
+ hidden_size = static_cast(qkv_hidden_sizes_[2]);
+ }
+
if (bias_dims[0] != weights_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'");
@@ -157,6 +203,33 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
mask_dims.size());
}
}
+
+ if (extra_add_qk != nullptr) {
+ const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims();
+
+ if (extra_add_qk_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' is expected to have 4 dimensions, got ",
+ extra_add_qk_dims.size());
+ }
+
+ if (extra_add_qk_dims[0] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ",
+ extra_add_qk_dims[0]);
+ }
+ if (extra_add_qk_dims[1] != num_heads_) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ",
+ extra_add_qk_dims[1]);
+ }
+ if (extra_add_qk_dims[2] != sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ",
+ extra_add_qk_dims[2]);
+ }
+ if (extra_add_qk_dims[3] != sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ",
+ extra_add_qk_dims[3]);
+ }
+ }
+
return Status::OK();
}
@@ -170,7 +243,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
- return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past);
+ return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, nullptr);
}
Tensor* AttentionBase::GetPresent(OpKernelContext* context,
@@ -203,6 +276,71 @@ template
Attention::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionCPUBase(info) {
}
+template  // Pre-packs one of the Q/K/V weight blocks, one packed slice per attention head; returns false if nothing was packed.
+bool Attention::IsPackWeightsSuccessful(int qkv_index,  // 0 = Q, 1 = K, 2 = V; selects which member buffer receives the result
+ AllocatorPtr alloc,  // allocator for the packed buffer; also bound into the buffer's deleter
+ size_t head_size,  // columns packed per head (N passed to MlasGemmPackB)
+ size_t input_hidden_size,  // rows of the weight matrix (K passed to MlasGemmPackB)
+ const T* weights_data,  // first column of this Q/K/V block inside the combined weight matrix
+ size_t weight_matrix_col_size,  // forwarded to MlasGemmPackB after the data pointer; presumably the leading dimension (ldb) - TODO confirm
+ /*out*/ PrePackedWeights* prepacked_weights) {  // when non-null, receives ownership of the packed buffer and its total size
+ size_t packb_size = MlasGemmPackBSize(head_size, input_hidden_size);  // bytes needed for ONE head's packed slice
+ if (packb_size == 0) {  // a size of 0 is treated as "cannot pack"; bail out before allocating
+ return false;
+ }
+
+ size_t loop_len = static_cast(num_heads_);  // one packed slice per head
+ size_t packed_weights_data_size = packb_size * loop_len; // The same size would be computed by AllocArray() below
+ auto* packed_weights_data = static_cast(alloc->AllocArray(packb_size, loop_len));
+
+ // Initialize memory to 0 as there could be some padding associated with the pre-packed
+ // buffer memory; we do not want it left uninitialized, since that could generate different hashes
+ // if and when we try to cache this pre-packed buffer for sharing between sessions.
+ memset(packed_weights_data, 0, packed_weights_data_size);
+ switch (qkv_index) {  // hand ownership of the raw buffer to the matching member
+ case 0:
+ q_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
+ q_packed_weights_size_ = packb_size;  // note: per-head slice size, not the total buffer size
+ break;
+ case 1:
+ k_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
+ k_packed_weights_size_ = packb_size;  // per-head slice size
+ break;
+ case 2:
+ v_packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
+ v_packed_weights_size_ = packb_size;  // per-head slice size
+ break;
+ default:  // NOTE(review): this path returns after the allocation above but before any BufferUniquePtr owns it - confirm it is unreachable or validate qkv_index earlier
+ return false;
+ }
+
+ for (size_t i = 0; i < loop_len; i++) {  // pack head i, then step source by head_size columns and destination by one slice
+ MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data);
+ packed_weights_data += packb_size;
+ weights_data += head_size;
+ }
+
+ bool share_prepacked_weights = (prepacked_weights != nullptr);
+ if (share_prepacked_weights) {  // sharing requested: move the buffer out of the member and into prepacked_weights
+ switch (qkv_index) {
+ case 0:
+ prepacked_weights->buffers_.push_back(std::move(q_packed_weights_));  // q_packed_weights_ is moved-from after this line
+ break;
+ case 1:
+ prepacked_weights->buffers_.push_back(std::move(k_packed_weights_));  // k_packed_weights_ is moved-from after this line
+ break;
+ case 2:
+ prepacked_weights->buffers_.push_back(std::move(v_packed_weights_));  // v_packed_weights_ is moved-from after this line
+ break;
+ default:
+ break;
+ }
+
+ prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size);  // total bytes across all heads
+ }
+ return true;
+}
+
template
Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
@@ -219,46 +357,47 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
return Status::OK();
}
- const size_t input_hidden_size = static_cast(weights_dims[0]);
- const size_t hidden_size_x3 = static_cast(weights_dims[1]);
- const size_t hidden_size = hidden_size_x3 / 3;
- const size_t head_size = hidden_size / num_heads_;
-
- // Bail out if the weights shape has an expected shape.
- if ((hidden_size == 0) || ((hidden_size % num_heads_) != 0) || (hidden_size_x3 != 3 * hidden_size)) {
- return Status::OK();
- }
-
const auto* weights_data = weights.Data();
+ const size_t input_hidden_size = static_cast(weights_dims[0]);
+ size_t q_hidden_size, k_hidden_size, v_hidden_size;
- packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size);
- if (packed_weights_size_ == 0) {
+ if (qkv_hidden_sizes_.size() != 0) {
+ q_hidden_size = qkv_hidden_sizes_[0];
+ k_hidden_size = qkv_hidden_sizes_[1];
+ v_hidden_size = qkv_hidden_sizes_[2];
+
+ if (q_hidden_size == 0 || k_hidden_size == 0 || v_hidden_size == 0) {
+ return Status::OK();
+ }
+
+ if (q_hidden_size % num_heads_ != 0 || k_hidden_size % num_heads_ != 0 || v_hidden_size % num_heads_ != 0) {
+ return Status::OK();
+ }
+ } else {
+ const size_t hidden_size_x3 = static_cast(weights_dims[1]);
+ const size_t hidden_size = hidden_size_x3 / 3;
+
+ if (hidden_size % num_heads_ != 0) {
+ return Status::OK();
+ }
+
+ q_hidden_size = hidden_size;
+ k_hidden_size = hidden_size;
+ v_hidden_size = hidden_size;
+ }
+
+ const size_t q_head_size = q_hidden_size / num_heads_;
+ const size_t k_head_size = k_hidden_size / num_heads_;
+ const size_t v_head_size = v_hidden_size / num_heads_;
+ const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size;
+
+ if (!IsPackWeightsSuccessful(0, alloc, q_head_size, input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
+ !IsPackWeightsSuccessful(1, alloc, k_head_size, input_hidden_size, weights_data + (num_heads_ * q_head_size), weight_matrix_col_size, prepacked_weights) ||
+ !IsPackWeightsSuccessful(2, alloc, v_head_size, input_hidden_size, weights_data + (num_heads_ * (q_head_size + k_head_size)), weight_matrix_col_size, prepacked_weights)) {
+ // At least one of Q, K, V failed to pack. Buffers already packed are not cleaned up here; this assumes the caller handles it.
return Status::OK();
}
- const size_t loop_len = static_cast(3) * num_heads_;
- size_t packed_weights_data_size = packed_weights_size_ * loop_len; // The same size would be computed by AllocArray() below
- auto* packed_weights_data = static_cast(alloc->AllocArray(packed_weights_size_, loop_len));
-
- // Initialize memory to 0 as there could be some padding associated with pre-packed
- // buffer memory and we don not want it uninitialized and generate different hashes
- // if and when we try to cache this pre-packed buffer for sharing between sessions.
- memset(packed_weights_data, 0, packed_weights_data_size);
-
- packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
-
- for (size_t i = 0; i < loop_len; i++) {
- MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, hidden_size_x3, packed_weights_data);
- packed_weights_data += packed_weights_size_;
- weights_data += head_size;
- }
-
- bool share_prepacked_weights = (prepacked_weights != nullptr);
- if (share_prepacked_weights) {
- prepacked_weights->buffers_.push_back(std::move(packed_weights_));
- prepacked_weights->buffer_sizes_.push_back(packed_weights_data_size);
- }
-
is_packed = true;
return Status::OK();
}
@@ -272,7 +411,9 @@ Status Attention::UseSharedPrePackedBuffers(std::vector& pre
}
used_shared_buffers = true;
- packed_weights_ = std::move(prepacked_buffers[0]);
+ q_packed_weights_ = std::move(prepacked_buffers[0]);
+ k_packed_weights_ = std::move(prepacked_buffers[1]);
+ v_packed_weights_ = std::move(prepacked_buffers[2]);
return Status::OK();
}
@@ -280,25 +421,35 @@ Status Attention::UseSharedPrePackedBuffers(std::vector& pre
template
Status Attention::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input(0);
- const Tensor* weights = packed_weights_ ? nullptr : context->Input(1);
+ const Tensor* weights = q_packed_weights_ ? nullptr : context->Input(1);
const Tensor* bias = context->Input(2);
+
const Tensor* mask_index = context->Input(3);
const Tensor* past = context->Input(4);
+ const Tensor* extra_add_qk = context->Input(5);
const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_);
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights_shape,
bias->Shape(),
mask_index,
- past));
+ past,
+ extra_add_qk));
const auto& shape = input->Shape().GetDims();
const int batch_size = static_cast(shape[0]);
const int sequence_length = static_cast(shape[1]);
const int input_hidden_size = static_cast(shape[2]);
+
+ int hidden_size;
+
+ if (qkv_hidden_sizes_.size() == 0) {
+ const auto& weights_dims = weights_shape.GetDims();
+ hidden_size = static_cast(weights_dims[1]) / 3;
+ } else {
+ hidden_size = static_cast(qkv_hidden_sizes_[2]);
+ }
- const auto& weights_dims = weights_shape.GetDims();
- const int hidden_size = static_cast(weights_dims[1]) / 3;
const int head_size = hidden_size / num_heads_;
std::vector output_shape(3);
@@ -309,19 +460,35 @@ Status Attention::Compute(OpKernelContext* context) const {
constexpr size_t element_size = sizeof(T);
+ int q_hidden_size = 0;
+ int k_hidden_size = 0;
+ int v_hidden_size = 0;
+ if (qkv_hidden_sizes_.size() == 0) {
+ q_hidden_size = hidden_size;
+ k_hidden_size = hidden_size;
+ v_hidden_size = hidden_size;
+ } else {
+ q_hidden_size = static_cast(qkv_hidden_sizes_[0]);
+ k_hidden_size = static_cast(qkv_hidden_sizes_[1]);
+ v_hidden_size = static_cast(qkv_hidden_sizes_[2]);
+ }
+ const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
+
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();
// Compute Q, K, V
- // gemm_data(BS, 3NH) = input(BS, D) x weights(D, 3NH) + bias(3NH)
- // D (input_hidden_size) is hidden dimension of input, where D could be larger than hidden_size (NH) when model is pruned.
- auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size);
+ // gemm_data(BS, NT) = input(BS, D) x weights(D, NT) + bias(NT)
+ // D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes
+ // (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively
+ auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
auto Q = reinterpret_cast(gemm_data);
- auto K = Q + static_cast(batch_size) * sequence_length * hidden_size;
- auto V = K + static_cast(batch_size) * sequence_length * hidden_size;
+ auto K = Q + static_cast(batch_size) * sequence_length * q_hidden_size;
+ auto V = K + static_cast(batch_size) * sequence_length * k_hidden_size;
+
T* QKV[3] = {Q, K, V};
{
@@ -339,14 +506,25 @@ Status Attention::Compute(OpKernelContext* context) const {
const int qkv_index = static_cast(i % 3);
int input_offset = batch_index * sequence_length * input_hidden_size;
- int weights_offset = qkv_index * hidden_size + head_index * head_size;
+
T* qkv_dest = QKV[qkv_index];
+ int head_size = qkv_head_size[qkv_index];
+ int weights_offset = 0;
+ int bias_offset = qkv_index * q_hidden_size + head_index * head_size;
+
+ if (q_packed_weights_ == nullptr) {
+ weights_offset = bias_offset;
+ } else {
+ weights_offset = head_index * head_size;
+ }
+
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
// TODO!! memcpy here makes it not worthwhile to use Gemm batch. Possible to post process?
- // broadcast 3NH -> (3.B.N.S.H)
- const T* broadcast_data_src = bias_data + weights_offset;
+ // broadcast NH -> (B.N.S.H) for each of Q, K, V
+ const T* broadcast_data_src = bias_data + bias_offset;
T* broadcast_data_dest = QKV[qkv_index] + qkv_offset;
+
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
memcpy(broadcast_data_dest, broadcast_data_src, head_size * sizeof(T));
broadcast_data_dest += head_size;
@@ -354,15 +532,22 @@ Status Attention::Compute(OpKernelContext* context) const {
// original transposed iteration
// A: input (BxSxD) (B.)S x D S x D
- // B: weights (Dx3xNxH) D x (3.N.)H D x H
- // C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
- if (packed_weights_) {
- const auto* packed_weight =
- static_cast(packed_weights_.get()) + packed_weights_size_ * (weights_offset / head_size);
+ // B: weights (DxNxT) D x (N.)T D x H
+ // C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H
+ if (q_packed_weights_) {
+ uint8_t* packed_weight;
+ if (qkv_index == 0) {
+ packed_weight = static_cast(q_packed_weights_.get()) + q_packed_weights_size_ * (weights_offset / head_size);
+ } else if (qkv_index == 1) {
+ packed_weight = static_cast(k_packed_weights_.get()) + k_packed_weights_size_ * (weights_offset / head_size);
+ } else {
+ packed_weight = static_cast(v_packed_weights_.get()) + v_packed_weights_size_ * (weights_offset / head_size);
+ }
+
MlasGemm(
CblasNoTrans, // TransA = no
sequence_length, // M = S
- head_size, // N = H
+ head_size, // N = H
input_hidden_size, // K = D
1.0f, // alpha
input_data + input_offset, // A
@@ -374,20 +559,20 @@ Status Attention::Compute(OpKernelContext* context) const {
nullptr); // use single-thread
} else {
math::GemmEx(
- CblasNoTrans, // TransA = no
- CblasNoTrans, // TransB = no
- sequence_length, // M = S
- head_size, // N = H
- input_hidden_size, // K = D
- 1.0f, // alpha
- input_data + input_offset, // A
- input_hidden_size, // lda = D
- weights_data + weights_offset, // B
- 3 * hidden_size, // ldb = 3NH
- 1.0f, // beta
- qkv_dest + qkv_offset, // C
- head_size, // ldc
- nullptr // use single-thread
+ CblasNoTrans, // TransA = no
+ CblasNoTrans, // TransB = no
+ sequence_length, // M = S
+ head_size, // N = H
+ input_hidden_size, // K = D
+ 1.0f, // alpha
+ input_data + input_offset, // A
+ input_hidden_size, // lda = D
+ weights_data + weights_offset, // B
+ q_hidden_size + k_hidden_size + v_hidden_size,// ldb = NH1 + NH2 + NH3
+ 1.0f, // beta
+ qkv_dest + qkv_offset, // C
+ head_size, // ldc
+ nullptr // use single-thread
);
}
}
@@ -397,8 +582,8 @@ Status Attention::Compute(OpKernelContext* context) const {
// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past, output,
batch_size, sequence_length,
- head_size, hidden_size, context);
+ qkv_head_size[0], qkv_head_size[2], v_hidden_size,
+ extra_add_qk, context);
}
-
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index f14268af3e..ce9033afd9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -33,16 +33,22 @@ class AttentionBase {
num_heads_ = static_cast(num_heads);
is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
+
+ if (!info.GetAttrs("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK() || qkv_hidden_sizes_.empty()) {
+ qkv_hidden_sizes_.resize(0);
+ }
}
Status CheckInputs(const TensorShape& input_shape,
const TensorShape& weights_shape,
const TensorShape& bias_shape,
const Tensor*& mask_index, // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr.
- const Tensor* past) const;
+ const Tensor* past,
+ const Tensor *extra_add_qk) const;
int num_heads_; // number of attention heads
bool is_unidirectional_; // whether every token can only attend to previous tokens.
+ std::vector qkv_hidden_sizes_; // Q, K, V path hidden layer sizes
};
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 063093d16c..46f2b292d5 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -26,8 +26,10 @@ class AttentionCPUBase : public AttentionBase {
Tensor* output, // output tensor
int batch_size, // batch size
int sequence_length, // sequence length
- int head_size, // head size
- int hidden_size, // hidden size
+ int qk_head_size, // qk_head_size
+ int v_head_size, // head_size
+ int v_hidden_size, // hidden_size
+ const Tensor* extra_add_qk,// extra add in QK. Its size is BxNxSxS
OpKernelContext* context) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
@@ -35,7 +37,7 @@ class AttentionCPUBase : public AttentionBase {
auto* tp = context->GetOperatorThreadPool();
int past_sequence_length = 0;
- Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
+ Tensor* present = GetPresent(context, past, batch_size, v_head_size, sequence_length, past_sequence_length);
// Total sequence length including that of past state: S* = S' + S
const int all_sequence_length = past_sequence_length + sequence_length;
@@ -61,18 +63,23 @@ class AttentionCPUBase : public AttentionBase {
const T* past_data = past != nullptr ? past->template Data() : nullptr;
T* present_data = present != nullptr ? present->template MutableData() : nullptr;
+ const T* extra_add_qk_data = nullptr;
+ if (extra_add_qk != nullptr) {
+ extra_add_qk_data = extra_add_qk->template Data();
+ }
+
ComputeAttentionProbs(static_cast(attention_probs), Q, K,
mask_index_data, mask_index_dims, static_cast(mask_data),
- batch_size, sequence_length, past_sequence_length, head_size,
- past_data, present_data, tp);
+ batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size,
+ past_data, present_data, tp, extra_add_qk_data);
// Compute the attentionScore * Value. It does: out_tmp(B, N, S, H) = attention_probs(B, N, S, S*) x V(B, N, S*, H)
auto out_tmp_data =
- allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * head_size * sizeof(T));
+ allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));
ComputeVxAttentionScore(output->template MutableData(), static_cast(out_tmp_data), static_cast(attention_probs), V,
- batch_size, sequence_length, past_sequence_length, head_size, hidden_size,
+ batch_size, sequence_length, past_sequence_length, v_head_size, v_hidden_size,
past_data, present_data, tp);
return Status::OK();
@@ -96,7 +103,8 @@ class AttentionCPUBase : public AttentionBase {
int head_size, // head size of self-attention
const T* past, // past state
T* present, // present state
- ThreadPool* tp) const {
+ ThreadPool* tp,
+ const T* extra_add_qk_data) const {
const int all_sequence_length = past_sequence_length + sequence_length; // S* = S' + S
const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // S' x H
const size_t input_chunk_length = static_cast(sequence_length) * head_size; // S x H
@@ -140,6 +148,13 @@ class AttentionCPUBase : public AttentionBase {
math::Gemm(CblasNoTrans, CblasTrans, sequence_length, all_sequence_length, head_size, alpha,
Q + input_chunk_length * i, k, 1.0,
reinterpret_cast(attention_probs) + sequence_length * all_sequence_length * i, nullptr);
+
+ if (extra_add_qk_data != nullptr) {
+ int extra_add_qk_offset = static_cast(i) * sequence_length * all_sequence_length;
+ for (int j = 0; j < sequence_length * all_sequence_length ; j++) {
+ attention_probs[extra_add_qk_offset+j] += extra_add_qk_data[extra_add_qk_offset + j];
+ }
+ }
}
});
}
diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
index c8f9a40c23..9f7168b81f 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -158,7 +158,8 @@ Status QAttention::Compute(OpKernelContext* context) const {
weights_shape,
bias->Shape(),
mask_index,
- past_tensor));
+ past_tensor,
+ nullptr));
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
"input scale must be a scalar or 1D tensor of size 1");
@@ -286,7 +287,7 @@ Status QAttention::Compute(OpKernelContext* context) const {
// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past_tensor, output,
batch_size, sequence_length,
- head_size, hidden_size, context);
+ head_size, head_size, hidden_size, nullptr, context);
}
} // namespace contrib
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index abad4abb75..d0050ac069 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -397,17 +397,32 @@ void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int p
auto& bias_shape = getInputShape(ctx, 2);
auto& bias_dims = bias_shape.dim();
- if (bias_dims.size() != 1 || bias_shape.dim(0).dim_value() % 3 != 0) {
+ if (bias_dims.size() != 1) {
fail_shape_inference("Invalid bias shape");
}
+ std::vector qkv_hidden_sizes;
+ getRepeatedAttribute(ctx, "qkv_hidden_sizes", qkv_hidden_sizes);
+
+ int64_t output_hidden_size;
+ if (qkv_hidden_sizes.size() != 0) {
+ if (qkv_hidden_sizes.size() != 3) {
+ fail_shape_inference("qkv_hidden_sizes should have 3 elements")
+ }
+ output_hidden_size = qkv_hidden_sizes[2];
+ } else {
+ output_hidden_size = bias_shape.dim(0).dim_value() / 3;
+ }
+
ONNX_NAMESPACE::TensorShapeProto output_shape;
for (auto& dim : input_dims) {
*output_shape.add_dim() = dim;
}
- output_shape.mutable_dim(2)->set_dim_value(bias_shape.dim(0).dim_value() / 3);
+
+ output_shape.mutable_dim(2)->set_dim_value(output_hidden_size);
updateOutputShape(ctx, 0, output_shape);
+ // TODO does the extra output need any changes?
if (ctx.getNumOutputs() > 1) {
if (hasInputShape(ctx, past_input_index)) {
auto& past_shape = getInputShape(ctx, past_input_index);
@@ -453,12 +468,17 @@ and present state are optional. Present state could appear in output even when p
"Whether every token can only attend to previous tokens. Default value is 0.",
AttributeProto::INT,
static_cast(0))
+ .Attr("qkv_hidden_sizes",
+ "Hidden layer sizes of Q, K, V paths in Attention",
+ AttributeProto::INTS,
+ OPTIONAL_VALUE)
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T")
.Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T")
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
.Input(3, "mask_index", "Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)"
"or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
.Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
+ .Input(5, "extra_add", "additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).", "T", OpSchema::Optional)
.Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
.Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
@@ -545,6 +565,7 @@ and present state are optional. Present state could appear in output even when p
.TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
constexpr int past_input_index = 8;
+
AttentionTypeAndShapeInference(ctx, past_input_index);
});
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 2c50b4a6b0..fdccf86fd4 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -1479,7 +1479,12 @@ class SymbolicShapeInference:
shape = self._get_shape(node, 0)
shape_bias = self._get_shape(node, 2)
assert len(shape) == 3 and len(shape_bias) == 1
- shape[2] = int(shape_bias[0] / 3)
+ qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes')
+ if qkv_hidden_sizes_attr is not None:
+ assert len(qkv_hidden_sizes_attr) == 3
+ shape[2] = int(qkv_hidden_sizes_attr[2])
+ else:
+ shape[2] = int(shape_bias[0] / 3)
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index c18225d6e9..7830971c4e 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -12,6 +12,7 @@ from onnx import helper, numpy_helper, TensorProto, NodeProto
from onnx_model import OnnxModel
from fusion_base import Fusion
from fusion_utils import FusionUtils, NumpyHelper
+from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto
logger = getLogger(__name__)
@@ -126,9 +127,29 @@ class FusionAttention(Fusion):
return num_heads, hidden_size
+ def get_add_qk_str(self, add_qk: NodeProto):
+ # Note: Does not work for dynamic models, reshape node shape inference would fail with more than 2 dims being -1
+ # inputs_ids has to be the name of input, this may be changes if needed
+ inputs_ids = self.model.find_graph_input("input_ids")
+ if inputs_ids == None:
+ logger.debug("no input with name \"input_ids\"")
+ return None
+ batch_size, seq_len = get_shape_from_type_proto(inputs_ids.type)
+
+ if batch_size < 0 or seq_len < 0:
+ logger.debug(f"batch_size: {batch_size} and seq_len {seq_len} cannot be -ve")
+ return None
+ shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model)
+ shape_infer_helper.infer({"batch_size": batch_size, "seq_len": seq_len})
+ if shape_infer_helper.get_edge_shape(add_qk.input[0]) != shape_infer_helper.get_edge_shape(add_qk.input[1]):
+ logger.debug(f"the shape of two inputs of {add_qk} is not same")
+ return None
+
+ return add_qk.input[1]
+
def create_attention_node(self, mask_index: str, q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto,
q_add: NodeProto, k_add: NodeProto, v_add: NodeProto, num_heads: int, hidden_size: int,
- input: str, output: str) -> Union[NodeProto, None]:
+ input: str, output: str, add_qk_str: str) -> Union[NodeProto, None]:
""" Create an Attention node.
Args:
@@ -165,6 +186,7 @@ class FusionAttention(Fusion):
return None
if not (k_weight and v_weight and q_bias and k_bias):
return None
+
qw = NumpyHelper.to_array(q_weight)
kw = NumpyHelper.to_array(k_weight)
vw = NumpyHelper.to_array(v_weight)
@@ -246,6 +268,10 @@ class FusionAttention(Fusion):
if mask_index is not None:
attention_inputs.append(mask_index)
+ if add_qk_str is not None:
+ attention_inputs.append("")
+ attention_inputs.append(add_qk_str)
+
attention_node = helper.make_node('Attention',
inputs=attention_inputs,
outputs=[output],
@@ -334,7 +360,7 @@ class FusionAttention(Fusion):
logger.debug("fuse_attention: failed to match v path")
return
(_, _, add_v, matmul_v) = v_nodes
-
+
is_distill = False
is_distill_add = False
qk_paths = {
@@ -365,7 +391,7 @@ class FusionAttention(Fusion):
if is_distill:
(_, where_qk, matmul_qk, _) = qk_nodes
elif is_distill_add:
- (_, _, where_qk, matmul_qk) = qk_nodes
+ (_, add_qk, where_qk, matmul_qk) = qk_nodes
else:
(_, add_qk, _, matmul_qk) = qk_nodes
@@ -392,6 +418,7 @@ class FusionAttention(Fusion):
# Note that Cast might be removed by OnnxRuntime so we match two patterns here.
mask_nodes = None
+ add_qk_str = None
if is_distill:
_, mask_nodes, _ = self.model.match_parent_paths(where_qk,
[(['Expand', 'Reshape', 'Equal'], [0, 0, 0]),
@@ -401,6 +428,11 @@ class FusionAttention(Fusion):
_, mask_nodes, _ = self.model.match_parent_paths(
where_qk, [(['Cast', 'Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0, 0]),
(['Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0])], output_name_to_node)
+ if add_qk is not None:
+ add_qk_str = self.get_add_qk_str(add_qk)
+ if add_qk_str is None:
+ logger.debug(f"fuse_attention: failed to verify shape inference of {add_qk}")
+ return
else:
_, mask_nodes, _ = self.model.match_parent_paths(
add_qk, [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0, 0]),
@@ -419,7 +451,7 @@ class FusionAttention(Fusion):
# the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately
new_node = self.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v,
q_num_heads, self.hidden_size, root_input,
- attention_last_node.output[0])
+ attention_last_node.output[0], add_qk_str)
if new_node is None:
return
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
index 52c3f1017d..c0f8f35517 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
@@ -140,7 +140,7 @@ class BertOnnxModelKeras(BertOnnxModelTF):
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v,
add_q, add_k, add_v, self.num_heads,
self.hidden_size, parent.output[0],
- reshape_qkv.output[0])
+ reshape_qkv.output[0], None)
if attention_node is None:
continue
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py
index 28bff0cbf5..e05593cfff 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py
@@ -384,7 +384,7 @@ class BertOnnxModelTF(BertOnnxModel):
attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_k, matmul_q, matmul_v,
add_k, add_q, add_v, self.num_heads,
self.hidden_size, parent.output[0],
- qkv_nodes[2].output[0])
+ qkv_nodes[2].output[0], None)
if attention_node is None:
continue
diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index 846699732c..9cd2863de4 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -38,11 +38,14 @@ static void RunAttentionTest(
MaskIndexType mask_index_type = kMaskIndexEnd,
int input_hidden_size = 0,
int max_sequence_length = 0,
- bool only_enable_cuda = false) {
+ bool only_enable_cuda = false,
+ bool only_enable_cpu = false,
+ std::vector qkv_sizes = {},
+ const std::vector& extra_add_data = {}) {
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
int min_cuda_architecture = use_float16 ? 530 : 0;
- bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant;
+ bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu;
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda;
int head_size = hidden_size / number_of_heads;
@@ -51,9 +54,21 @@ static void RunAttentionTest(
tester.AddAttribute("num_heads", static_cast(number_of_heads));
tester.AddAttribute("unidirectional", static_cast(is_unidirectional ? 1 : 0));
+ int32_t matrix_size;
+ int32_t output_hidden_size;
+ if (qkv_sizes.size() != 0) {
+ matrix_size = qkv_sizes[0] + qkv_sizes[1] + qkv_sizes[2];
+ std::vector sizes_attribute{qkv_sizes[0], qkv_sizes[1], qkv_sizes[2]};
+ tester.AddAttribute>("qkv_hidden_sizes", sizes_attribute);
+ output_hidden_size = qkv_sizes[2];
+ } else {
+ matrix_size = 3 * hidden_size;
+ output_hidden_size = hidden_size;
+ }
+
std::vector input_dims = {batch_size, sequence_length, input_hidden_size};
- std::vector weights_dims = {input_hidden_size, 3 * hidden_size};
- std::vector bias_dims = {3 * hidden_size};
+ std::vector weights_dims = {input_hidden_size, matrix_size};
+ std::vector bias_dims = {matrix_size};
std::vector mask_index_dims_1 = {batch_size};
std::vector mask_index_dims_2 = {2 * batch_size};
@@ -88,7 +103,7 @@ static void RunAttentionTest(
std::vector past_dims = {2, batch_size, number_of_heads, past_sequence_length, head_size};
std::vector present_dims = {2, batch_size, number_of_heads, past_sequence_length + sequence_length, head_size};
- std::vector output_dims = {batch_size, sequence_length, hidden_size};
+ std::vector output_dims = {batch_size, sequence_length, output_hidden_size};
if (use_float16) {
tester.AddInput("input", input_dims, ToFloat16(input_data));
@@ -122,6 +137,14 @@ static void RunAttentionTest(
}
}
+ if (extra_add_data.size() > 0) {
+ if (!use_past_state) {
+ tester.AddOptionalInputEdge();
+ }
+ std::vector extra_add_data_dims = {batch_size, number_of_heads, sequence_length, sequence_length};
+ tester.AddInput("extra_add_qk", extra_add_data_dims, extra_add_data);
+ }
+
if (enable_cuda) {
std::vector> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
@@ -155,17 +178,20 @@ static void RunAttentionTest(
MaskIndexType mask_index_type = kMaskIndexEnd,
int input_hidden_size = 0,
int max_sequence_length = 0,
- bool only_enable_cuda = false) {
- RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
- batch_size, sequence_length, hidden_size, number_of_heads,
- use_float16, is_unidirectional, use_past_state, past_sequence_length,
- past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
- only_enable_cuda);
- RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
- batch_size, sequence_length, hidden_size, number_of_heads,
- use_float16, is_unidirectional, use_past_state, past_sequence_length,
- past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
- only_enable_cuda);
+ bool only_enable_cuda = false,
+ bool only_enable_cpu = false,
+ const std::vector qkv_sizes = {},
+ const std::vector& extra_add_data = {}) {
+ RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ use_float16, is_unidirectional, use_past_state, past_sequence_length,
+ past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
+ only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
+ RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ use_float16, is_unidirectional, use_past_state, past_sequence_length,
+ past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
+ only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
}
TEST(AttentionTest, AttentionBatch1) {
@@ -197,6 +223,163 @@ TEST(AttentionTest, AttentionBatch1) {
batch_size, sequence_length, hidden_size, number_of_heads);
}
+ // Single-batch attention with explicit Q/K/V hidden sizes {6, 6, 4}.
+ // The weight and bias tables therefore hold 6 + 6 + 4 = 16 columns, and the expected
+ // output keeps the V size (4) as its last dimension.
+ TEST(AttentionTest, AttentionBatch1WithQKVAttr1) {
+ int batch_size = 1;
+ int sequence_length = 2;
+ int hidden_size = 4;
+ int number_of_heads = 2;
+
+ std::vector input_data = {
+ 0.8f, -0.5f, 0.0f, 1.f,
+ 0.5f, 0.2f, 0.3f, -0.6f};
+
+ std::vector qkv_sizes = {
+ 6, 6, 4};
+
+ std::vector weight_data = {
+ 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
+ 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
+
+ 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
+ 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f,
+
+ 0.3f, 0.2f, 4.0f, 2.2f, 2.4f, 3.3f, 2.1f, 4.2f, 0.5f, 0.1f, 0.4f, 1.6f,
+ 0.4f, 0.8f, 0.9f, 0.1f
+ };
+
+ std::vector bias_data = {
+ -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
+ 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f,
+ 0.5f, 0.7f, 0.2f, 1.2f};
+
+ std::vector mask_index_data = {2L};
+
+ std::vector output_data = {
+ 3.1967618465423584f, 0.51903456449508667f, 0.63051539659500122f, 2.9394614696502686f,
+ 0.65332180261611938f, 1.000949501991272f, 0.74175024032592773f, 2.8231701850891113f};
+
+ // Trailing `false, true` map to only_enable_cuda / only_enable_cpu: this case runs on CPU only.
+ RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
+ 0, false, true, qkv_sizes);
+ }
+
+ // Single-batch attention with explicit Q/K/V hidden sizes {6, 6, 2}.
+ // The weight and bias tables hold 6 + 6 + 2 = 14 columns, and the expected output has
+ // sequence_length (2) x V size (2) = 4 values.
+ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) {
+ int batch_size = 1;
+ int sequence_length = 2;
+ int hidden_size = 4;
+ int number_of_heads = 2;
+
+ std::vector input_data = {
+ -0.031707365f, 0.053643607f, 0.057394292f, -0.019800574f, 0.075466447f, -0.0034214978f, 0.012995008f, -0.019587509f};
+
+ std::vector qkv_sizes = {
+ 6, 6, 2};
+
+ std::vector weight_data = {
+ 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
+ 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
+
+ 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
+ 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f,
+
+ 0.3f, 0.2f, 4.0f, 2.2f, 2.4f, 3.3f, 2.1f, 4.2f};
+
+ std::vector bias_data = {
+ -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
+ 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f,
+ 0.5f, 0.7f};
+
+ std::vector mask_index_data = {2L};
+
+ std::vector output_data = {
+ 0.64932525157928467f, 0.79390722513198853f, 0.64932847023010254f, 0.79375863075256348f};
+
+ // Trailing `false, true` map to only_enable_cuda / only_enable_cpu: this case runs on CPU only.
+ RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
+ 0, false, true, qkv_sizes);
+ }
+
+ // Single-batch attention with the optional extra-add input.
+ // qkv_sizes is left empty, so RunAttentionTest falls back to 3 * hidden_size columns.
+ // extra_add_qk holds batch_size * number_of_heads * sequence_length * sequence_length
+ // = 1 * 2 * 2 * 2 = 8 values.
+ TEST(AttentionTest, AttentionBatch1ExtraAdd) {
+ int batch_size = 1;
+ int sequence_length = 2;
+ int hidden_size = 4;
+ int number_of_heads = 2;
+
+ std::vector input_data = {
+ 0.8f, -0.5f, 0.0f, 1.f,
+ 0.5f, 0.2f, 0.3f, -0.6f};
+
+ std::vector qkv_sizes = {};
+
+ std::vector weight_data = {
+ 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
+ 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
+ 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
+ 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
+
+ std::vector bias_data = {
+ -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
+ 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
+
+ std::vector mask_index_data = {2L};
+
+ std::vector extra_add_qk = {
+ 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f};
+
+ std::vector output_data = {
+ 4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f,
+ -1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f};
+
+ // Trailing `false, true` map to only_enable_cuda / only_enable_cpu: this case runs on CPU only.
+ RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
+ 0, false, true, qkv_sizes, extra_add_qk);
+ }
+
+ // Two-batch variant of the extra-add case.
+ // The input rows, mask index and extra_add_qk values are repeated once per batch, and the
+ // expected output lists the same eight values for each batch.
+ TEST(AttentionTest, AttentionBatch2ExtraAdd) {
+ int batch_size = 2;
+ int sequence_length = 2;
+ int hidden_size = 4;
+ int number_of_heads = 2;
+
+ std::vector input_data = {
+ 0.8f, -0.5f, 0.0f, 1.f,
+ 0.5f, 0.2f, 0.3f, -0.6f,
+ 0.8f, -0.5f, 0.0f, 1.f,
+ 0.5f, 0.2f, 0.3f, -0.6f};
+
+ std::vector qkv_sizes = {};
+
+ std::vector weight_data = {
+ 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
+ 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
+ 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
+ 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
+
+ std::vector bias_data = {
+ -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f,
+ 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
+
+ std::vector mask_index_data = {2L, 2L};
+
+ std::vector extra_add_qk = {
+ 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f,
+ 0.2f, -0.1f, 0.4f, 2.5f, 1.6f, -1.1f, 0.4f, -2.5f};
+
+ std::vector output_data = {
+ 4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f,
+ -1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f,
+ 4.066014289855957f, 0.068997815251350403f, 4.25f, 5.6499996185302734f,
+ -1.8799558877944946f, 0.32488855719566345f, 4.25f, 5.6499996185302734f};
+
+ // Trailing `false, true` map to only_enable_cuda / only_enable_cpu: this case runs on CPU only.
+ RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
+ 0, false, true, qkv_sizes, extra_add_qk);
+ }
+
TEST(AttentionTest, AttentionBatch1_Float16) {
int batch_size = 1;
int sequence_length = 2;
diff --git a/onnxruntime/test/python/transformers/bert_model_generator.py b/onnxruntime/test/python/transformers/bert_model_generator.py
index dc9a504578..4764334ada 100644
--- a/onnxruntime/test/python/transformers/bert_model_generator.py
+++ b/onnxruntime/test/python/transformers/bert_model_generator.py
@@ -6,6 +6,7 @@
import onnx
import math
+import numpy as np
from typing import List
from packaging import version
from onnx import helper, TensorProto
@@ -17,7 +18,7 @@ def float_tensor(name: str, shape: List[int], random=False):
total_elements = 1
for x in shape:
total_elements *= x
- weights = [random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements
+ weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements
return helper.make_tensor(name, TensorProto.FLOAT, shape, weights)