mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Update Attention and QAttention to support pruned model (#6819)
* update Attention operator spec to support pruned model * update Attention and QAttention cpu & cuda kernel * Fix invalid embed layer norm fusion test models.
This commit is contained in:
parent
cb8d8464bc
commit
2d6e10ba00
14 changed files with 247 additions and 188 deletions
|
|
@ -54,14 +54,17 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
const Tensor*& mask_index,
|
||||
const Tensor* past) const {
|
||||
// Input shapes:
|
||||
// input : (batch_size, sequence_length, hidden_size)
|
||||
// weights : (hidden_size, 3 * hidden_size)
|
||||
// input : (batch_size, sequence_length, input_hidden_size)
|
||||
// weights : (input_hidden_size, 3 * hidden_size)
|
||||
// bias : (3 * hidden_size)
|
||||
// mask_index : nullptr, (batch_size), (2 * batch_size),
|
||||
// or (batch_size, 1), (1, 1)
|
||||
// or (batch_size, past_sequence_length + sequence_length)
|
||||
// or (batch_size, sequence_length, past_sequence_length + sequence_length)
|
||||
// past : (2, batch_size, num_heads, past_sequence_length, head_size)
|
||||
//
|
||||
// Where hidden_size = num_heads * head_size.
|
||||
// When a model is pruned (like some attention heads are removed), hidden_size < input_hidden_size.
|
||||
|
||||
const auto& dims = input_shape.GetDims();
|
||||
if (dims.size() != 3) {
|
||||
|
|
@ -70,11 +73,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
}
|
||||
int batch_size = static_cast<int>(dims[0]);
|
||||
int sequence_length = static_cast<int>(dims[1]);
|
||||
int hidden_size = static_cast<int>(dims[2]);
|
||||
if (hidden_size % num_heads_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 0 dimension 2 should be divisiable by value of the num_heads attribute.");
|
||||
}
|
||||
|
||||
const auto& weights_dims = weights_shape.GetDims();
|
||||
if (weights_dims.size() != 2) {
|
||||
|
|
@ -85,8 +83,15 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 1 dimension 0 should have same length as dimension 2 of input 0");
|
||||
}
|
||||
if (weights_dims[1] != 3 * weights_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' dimension 1 should be 3 times of dimension 0");
|
||||
|
||||
int hidden_size = static_cast<int>(weights_dims[1]) / 3;
|
||||
if (3 * hidden_size != static_cast<int>(weights_dims[1])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 1 dimension 1 should be 3 times of hidden dimension");
|
||||
}
|
||||
|
||||
if (hidden_size % num_heads_ != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
|
||||
}
|
||||
|
||||
const auto& bias_dims = bias_shape.GetDims();
|
||||
|
|
@ -125,7 +130,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
const auto& mask_dims = mask_index->Shape().GetDims();
|
||||
if (mask_dims.size() == 1) {
|
||||
if (static_cast<int>(mask_dims[0]) != batch_size && static_cast<int>(mask_dims[0]) != 2 * batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' dimension 0 shall have length of batch_size or 2 * batch_size");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
|
||||
}
|
||||
} else if (mask_dims.size() == 2) {
|
||||
if (static_cast<int>(mask_dims[0]) != batch_size || static_cast<int>(mask_dims[1]) != past_sequence_length + sequence_length) {
|
||||
|
|
@ -134,12 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
// Mask will have same value after propagation, which has same effect as no mask.
|
||||
mask_index = nullptr;
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)");
|
||||
}
|
||||
}
|
||||
} else if (mask_dims.size() == 3) {
|
||||
if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' of 3d shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
|
||||
}
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ",
|
||||
|
|
@ -193,8 +198,9 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pack
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
const size_t hidden_size = static_cast<size_t>(weights_dims[0]);
|
||||
const size_t input_hidden_size = static_cast<size_t>(weights_dims[0]);
|
||||
const size_t hidden_size_x3 = static_cast<size_t>(weights_dims[1]);
|
||||
const size_t hidden_size = hidden_size_x3 / 3;
|
||||
const size_t head_size = hidden_size / num_heads_;
|
||||
|
||||
// Bail out if the weights shape has an expected shape.
|
||||
|
|
@ -204,7 +210,7 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pack
|
|||
|
||||
const auto* weights_data = weights.Data<T>();
|
||||
|
||||
packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size);
|
||||
packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size);
|
||||
if (packed_weights_size_ == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -215,7 +221,7 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pack
|
|||
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
|
||||
for (size_t i = 0; i < loop_len; i++) {
|
||||
MlasGemmPackB(CblasNoTrans, head_size, hidden_size, weights_data, hidden_size_x3, packed_weights_data);
|
||||
MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, hidden_size_x3, packed_weights_data);
|
||||
packed_weights_data += packed_weights_size_;
|
||||
weights_data += head_size;
|
||||
}
|
||||
|
|
@ -232,8 +238,9 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
const Tensor* past = context->Input<Tensor>(4);
|
||||
|
||||
const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape());
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
|
||||
weights ? weights->Shape() : weight_shape_,
|
||||
weights_shape,
|
||||
bias->Shape(),
|
||||
mask_index,
|
||||
past));
|
||||
|
|
@ -241,10 +248,17 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
const auto& shape = input->Shape().GetDims();
|
||||
const int batch_size = static_cast<int>(shape[0]);
|
||||
const int sequence_length = static_cast<int>(shape[1]);
|
||||
const int hidden_size = static_cast<int>(shape[2]);
|
||||
const int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
const auto& weights_dims = weights_shape.GetDims();
|
||||
const int hidden_size = static_cast<int>(weights_dims[1]) / 3;
|
||||
const int head_size = hidden_size / num_heads_;
|
||||
|
||||
Tensor* output = context->Output(0, shape);
|
||||
std::vector<int64_t> output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
|
||||
|
|
@ -253,7 +267,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
// Compute Q, K, V
|
||||
// gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH)
|
||||
// gemm_data(BS, 3NH) = input(BS, D) x weights(D, 3NH) + bias(3NH)
|
||||
// D (input_hidden_size) is hidden dimension of input, where D could be larger than hidden_size (NH) when model is pruned.
|
||||
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
|
||||
|
||||
|
|
@ -269,14 +284,14 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
const auto* bias_data = bias->template Data<T>();
|
||||
|
||||
const double cost =
|
||||
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(hidden_size);
|
||||
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(input_hidden_size);
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t i = begin; i != end; ++i) {
|
||||
const int batch_index = static_cast<int>((i / 3) / num_heads_);
|
||||
const int head_index = static_cast<int>((i / 3) % num_heads_);
|
||||
const int qkv_index = static_cast<int>(i % 3);
|
||||
|
||||
int input_offset = batch_index * sequence_length * hidden_size;
|
||||
int input_offset = batch_index * sequence_length * input_hidden_size;
|
||||
int weights_offset = qkv_index * hidden_size + head_index * head_size;
|
||||
T* qkv_dest = QKV[qkv_index];
|
||||
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
|
||||
|
|
@ -290,8 +305,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
// original transposed iteration
|
||||
// A: input (BxSxNxH) (B.)S x NH S x NH
|
||||
// B: weights (NxHx3xNxH) NH x (3.N.)H NH x H
|
||||
// A: input (BxSxD) (B.)S x D S x D
|
||||
// B: weights (Dx3xNxH) D x (3.N.)H D x H
|
||||
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
|
||||
if (packed_weights_) {
|
||||
const auto* packed_weight =
|
||||
|
|
@ -300,30 +315,31 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
CblasNoTrans, // TransA = no
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
hidden_size, // K = NH
|
||||
input_hidden_size, // K = D
|
||||
1.0f, // alpha
|
||||
input_data + input_offset, // A
|
||||
hidden_size, // lda = NH
|
||||
input_hidden_size, // lda = D
|
||||
packed_weight, // B
|
||||
1.0f, // beta
|
||||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
nullptr); // use single-thread
|
||||
} else {
|
||||
math::GemmEx<float, ThreadPool>(CblasNoTrans, // TransA = no
|
||||
CblasNoTrans, // TransB = no
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
hidden_size, // K = NH
|
||||
1.0f, // alpha
|
||||
input_data + input_offset, // A
|
||||
hidden_size, // lda = NH
|
||||
weights_data + weights_offset, // B
|
||||
3 * hidden_size, // ldb = 3NH
|
||||
1.0f, // beta
|
||||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
nullptr // use single-thread
|
||||
math::GemmEx<float, ThreadPool>(
|
||||
CblasNoTrans, // TransA = no
|
||||
CblasNoTrans, // TransB = no
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
input_hidden_size, // K = D
|
||||
1.0f, // alpha
|
||||
input_data + input_offset, // A
|
||||
input_hidden_size, // lda = D
|
||||
weights_data + weights_offset, // B
|
||||
3 * hidden_size, // ldb = 3NH
|
||||
1.0f, // beta
|
||||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
nullptr // use single-thread
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,8 +66,9 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
const size_t hidden_size = static_cast<size_t>(weights_dims[0]);
|
||||
const size_t input_hidden_size = static_cast<size_t>(weights_dims[0]);
|
||||
const size_t hidden_size_x3 = static_cast<size_t>(weights_dims[1]);
|
||||
const size_t hidden_size = hidden_size_x3 / 3;
|
||||
const size_t head_size = hidden_size / num_heads_;
|
||||
|
||||
// Bail out if the weights shape has an expected shape.
|
||||
|
|
@ -78,7 +79,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
|
|||
const auto* weights_data = static_cast<const uint8_t*>(weights.DataRaw());
|
||||
weights_is_signed_ = weights.IsDataType<int8_t>();
|
||||
|
||||
packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size, weights_is_signed_);
|
||||
packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size, weights_is_signed_);
|
||||
if (packed_weights_size_ == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -89,7 +90,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
|
|||
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
|
||||
|
||||
for (size_t i = 0; i < loop_len; i++) {
|
||||
MlasGemmPackB(head_size, hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data);
|
||||
MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data);
|
||||
packed_weights_data += packed_weights_size_;
|
||||
weights_data += head_size;
|
||||
}
|
||||
|
|
@ -102,8 +103,8 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
|
|||
template <typename T>
|
||||
Status QAttention<T>::Compute(OpKernelContext* context) const {
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 0 - input : (batch_size, sequence_length, input_hidden_size)
|
||||
// Input 1 - weights : (input_hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - input_scale : scalar
|
||||
// Input 4 - weight_scale : scalar
|
||||
|
|
@ -124,8 +125,9 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
const Tensor* w_zp_tensor = context->Input<Tensor>(7);
|
||||
const Tensor* past_tensor = context->Input<Tensor>(8);
|
||||
|
||||
const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape());
|
||||
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(),
|
||||
packed_weights_ ? weight_shape_ : weights->Shape(),
|
||||
weights_shape,
|
||||
bias->Shape(),
|
||||
mask_index,
|
||||
past_tensor));
|
||||
|
|
@ -157,10 +159,17 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
const auto& shape = input->Shape();
|
||||
const int batch_size = static_cast<int>(shape[0]);
|
||||
const int sequence_length = static_cast<int>(shape[1]);
|
||||
const int hidden_size = static_cast<int>(shape[2]);
|
||||
const int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
const auto hidden_size_x3 = weights_shape.GetDims()[1];
|
||||
const int hidden_size = static_cast<int>(hidden_size_x3) / 3;
|
||||
const int head_size = hidden_size / num_heads_;
|
||||
|
||||
Tensor* output = context->Output(0, shape);
|
||||
std::vector<int64_t> output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
|
|
@ -168,7 +177,8 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
constexpr size_t element_size = sizeof(T);
|
||||
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
// STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, NH) x weights(NH, 3NH)) + bias(3NH)
|
||||
// STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, D) x weights(D, 3NH)) + bias(3NH)
|
||||
// D is hidden dimension of input, where input_hidden_size (D) could be larger than hidden_size (NH) when model is pruned.
|
||||
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
|
||||
|
||||
|
|
@ -186,21 +196,21 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
const bool weights_is_signed = packed_weights_ ? weights_is_signed_ : weights->IsDataType<int8_t>();
|
||||
|
||||
const double cost =
|
||||
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(hidden_size);
|
||||
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(input_hidden_size);
|
||||
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t i = begin; i != end; ++i) {
|
||||
const int batch_index = static_cast<int>((i / 3) / num_heads_);
|
||||
const int head_index = static_cast<int>((i / 3) % num_heads_);
|
||||
const int qkv_index = static_cast<int>(i % 3);
|
||||
|
||||
int input_offset = batch_index * sequence_length * hidden_size;
|
||||
int input_offset = batch_index * sequence_length * input_hidden_size;
|
||||
int weights_offset = qkv_index * hidden_size + head_index * head_size;
|
||||
float* qkv_dest = QKV[qkv_index];
|
||||
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
|
||||
|
||||
// original transposed iteration
|
||||
// A: input (BxSxNxH) (B.)S x NH S x NH
|
||||
// B: weights (NxHx3xNxH) NH x (3.N.)H NH x H
|
||||
// A: input (BxSxD) (B.)S x D S x D
|
||||
// B: weights (Dx3xNxH) D x (3.N.)H D x H
|
||||
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
|
||||
|
||||
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset,
|
||||
|
|
@ -215,9 +225,9 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
MlasGemm(
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
hidden_size, // K = NH
|
||||
input_hidden_size, // K = D
|
||||
input_data + input_offset, // A
|
||||
hidden_size, // lda = NH
|
||||
input_hidden_size, // lda = D
|
||||
input_zero_point, // input zero point
|
||||
packed_weight, // B
|
||||
weight_zero_point, // weight zero point
|
||||
|
|
@ -233,9 +243,9 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
|
|||
MlasGemm(
|
||||
sequence_length, // M = S
|
||||
head_size, // N = H
|
||||
hidden_size, // K = NH
|
||||
input_hidden_size, // K = D
|
||||
input_data + input_offset, // A
|
||||
hidden_size, // lda = NH
|
||||
input_hidden_size, // lda = D
|
||||
input_zero_point, // input zero point
|
||||
weights_data + weights_offset, // B
|
||||
3 * hidden_size, // ldb = 3NH
|
||||
|
|
|
|||
|
|
@ -41,16 +41,23 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* past = context->Input<Tensor>(4);
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past));
|
||||
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Output 0 - output : (batch_size, sequence_length, hidden_size)
|
||||
// input shape (batch_size, sequence_length, input_hidden_size)
|
||||
const auto& shape = input->Shape();
|
||||
int batch_size = static_cast<int>(shape[0]);
|
||||
int sequence_length = static_cast<int>(shape[1]);
|
||||
int hidden_size = static_cast<int>(shape[2]);
|
||||
int head_size = hidden_size / num_heads_;
|
||||
int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
Tensor* output = context->Output(0, shape);
|
||||
// bias shape (3 * hidden_size)
|
||||
const auto& bias_shape = bias->Shape();
|
||||
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
|
||||
int head_size = hidden_size / num_heads_;
|
||||
|
||||
std::vector<int64_t> output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
int past_sequence_length = 0;
|
||||
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
|
||||
|
|
@ -61,7 +68,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Use GEMM for fully connection.
|
||||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int k = hidden_size;
|
||||
int k = input_hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
|
|
|||
|
|
@ -49,17 +49,6 @@ Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
|
|||
const Tensor* i_zp_tensor,
|
||||
const Tensor* w_zp_tensor,
|
||||
const Tensor* past_tensor) const {
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - input_scale : scalar
|
||||
// Input 4 - weight_scale : scalar
|
||||
// Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
|
||||
// Input 6 - input_zero_point : scalar
|
||||
// Input 7 - weight_zero_point : scalar
|
||||
// Output : (batch_size, sequence_length, hidden_size)
|
||||
|
||||
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor));
|
||||
|
||||
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
|
||||
|
|
@ -89,12 +78,12 @@ Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
|
|||
template <typename T>
|
||||
Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 0 - input : (batch_size, sequence_length, input_hidden_size)
|
||||
// Input 1 - weights : (input_hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - input_scale : scalar
|
||||
// Input 4 - weight_scale : scalar
|
||||
// Input 5 - mask_index : (batch_size)
|
||||
// Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
|
||||
// Input 6 - input_zero_point : scalar
|
||||
// Input 7 - weight_zero_point : scalar
|
||||
// Input 8 - past : (2, batch_size, num_heads, past_sequence_length, head_size)
|
||||
|
|
@ -124,10 +113,17 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
const auto& shape = input->Shape();
|
||||
int batch_size = static_cast<int>(shape[0]);
|
||||
int sequence_length = static_cast<int>(shape[1]);
|
||||
int hidden_size = static_cast<int>(shape[2]);
|
||||
int head_size = hidden_size / num_heads_;
|
||||
int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
Tensor* output = context->Output(0, shape);
|
||||
const auto& bias_shape = bias->Shape();
|
||||
const int hidden_size = static_cast<int>(bias_shape.GetDims()[0]) / 3;
|
||||
const int head_size = hidden_size / num_heads_;
|
||||
|
||||
std::vector<int64_t> output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
cublasHandle_t cublas = CublasHandle();
|
||||
const size_t element_size = sizeof(T);
|
||||
|
|
@ -135,7 +131,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Use GEMM for fully connection.
|
||||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int k = hidden_size;
|
||||
int k = input_hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
|
||||
auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(batch_size * sequence_length * 3 * hidden_size);
|
||||
|
||||
|
|
|
|||
|
|
@ -275,6 +275,59 @@ void FusedMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
|
|||
updateOutputShape(ctx, 0, resultShape);
|
||||
}
|
||||
|
||||
|
||||
// Type and shape inference shared by the Attention and QAttention schemas.
//
// ctx              - inference context of the node being processed.
// past_input_index - index of the optional 'past' input; the schemas in this
//                    file pass 4 (Attention) and 8 (QAttention).
//
// Output 0: element type of input 2 (bias); shape is
//           (input.dim(0), input.dim(1), bias_length / 3). The last dimension
//           comes from the bias rather than from the input because the input's
//           last dimension is not necessarily equal to hidden_size.
// Output 1: (only when requested and 'past' has a shape) the shape of 'past'
//           with dimension 3 grown by the input sequence length.
void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) {
  // Type inference: every output takes the element type of the bias input.
  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
  if (ctx.getNumOutputs() > 1) {
    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 1);
  }

  // Shape inference needs both the input (0) and the bias (2) shapes.
  if (!hasInputShape(ctx, 0) || !hasInputShape(ctx, 2)) {
    return;
  }

  const auto& input_shape = getInputShape(ctx, 0);
  const auto& input_dims = input_shape.dim();
  if (input_dims.size() != 3) {
    fail_shape_inference("Inputs 0 shall be 3 dimensions");
  }

  const auto& bias_shape = getInputShape(ctx, 2);
  const auto& bias_dims = bias_shape.dim();
  if (bias_dims.size() != 1) {
    fail_shape_inference("Invalid bias shape");
  }

  // dim_value() reads as 0 when the dimension is symbolic or unset, so the
  // divisibility check (and the derived hidden size) is only meaningful when
  // a concrete value is present.
  const bool bias_length_known = bias_dims[0].has_dim_value();
  if (bias_length_known && bias_dims[0].dim_value() % 3 != 0) {
    fail_shape_inference("Invalid bias shape");
  }

  ONNX_NAMESPACE::TensorShapeProto output_shape;
  *output_shape.add_dim() = input_dims[0];
  *output_shape.add_dim() = input_dims[1];
  // A freshly added dimension carries neither a value nor a symbol, i.e. it is
  // unknown; it stays that way when the bias length is not a concrete value.
  auto* hidden_dim = output_shape.add_dim();
  if (bias_length_known) {
    hidden_dim->set_dim_value(bias_dims[0].dim_value() / 3);
  }
  updateOutputShape(ctx, 0, output_shape);

  // The 'present' output can only be inferred when it is requested and the
  // 'past' input carries a shape.
  if (ctx.getNumOutputs() > 1 && hasInputShape(ctx, past_input_index)) {
    const auto& past_shape = getInputShape(ctx, past_input_index);
    const auto& past_dims = past_shape.dim();
    if (past_dims.size() != 5) {
      // Report the actual input index: it differs between callers.
      fail_shape_inference("Inputs ", past_input_index, " shall be 5 dimensions");
    }

    if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
      const auto all_sequence_length = past_dims[3].dim_value() + input_dims[1].dim_value();

      // present has the same shape as past except for the sequence axis.
      ONNX_NAMESPACE::TensorShapeProto present_shape;
      for (const auto& dim : past_dims) {
        *present_shape.add_dim() = dim;
      }
      present_shape.mutable_dim(3)->set_dim_value(all_sequence_length);

      updateOutputShape(ctx, 1, present_shape);
    }
  }
}
|
||||
|
||||
void RegisterBertSchemas() {
|
||||
static const char* Attention_ver1_doc = R"DOC(
|
||||
Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
|
||||
|
|
@ -296,8 +349,8 @@ and present state are optional. Present state could appear in output even when p
|
|||
"Whether every token can only attend to previous tokens. Default value is 0.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
|
||||
.Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
|
||||
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T")
|
||||
.Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T")
|
||||
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
|
||||
.Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
|
||||
.Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
|
||||
|
|
@ -306,42 +359,8 @@ and present state are optional. Present state could appear in output even when p
|
|||
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 1);
|
||||
}
|
||||
|
||||
if (hasInputShape(ctx, 0)) {
|
||||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
auto& input_shape = getInputShape(ctx, 0);
|
||||
auto& input_dims = input_shape.dim();
|
||||
if (input_dims.size() != 3) {
|
||||
fail_shape_inference("Inputs 0 shall be 3 dimensions");
|
||||
}
|
||||
|
||||
if (hasInputShape(ctx, 4)) {
|
||||
auto& past_shape = getInputShape(ctx, 4);
|
||||
auto& past_dims = past_shape.dim();
|
||||
if (past_dims.size() != 5) {
|
||||
fail_shape_inference("Inputs 4 shall be 5 dimensions");
|
||||
}
|
||||
|
||||
if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
|
||||
auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value();
|
||||
|
||||
ONNX_NAMESPACE::TensorShapeProto present_shape;
|
||||
for (auto& dim : past_dims) {
|
||||
*present_shape.add_dim() = dim;
|
||||
}
|
||||
present_shape.mutable_dim(3)->set_dim_value(all_sequence_length);
|
||||
|
||||
updateOutputShape(ctx, 1, present_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
constexpr int past_input_index = 4;
|
||||
AttentionTypeAndShapeInference(ctx, past_input_index);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(QAttention)
|
||||
|
|
@ -356,12 +375,12 @@ and present state are optional. Present state could appear in output even when p
|
|||
.Input(
|
||||
0,
|
||||
"input",
|
||||
"3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size",
|
||||
"3D input tensor with shape (batch_size, sequence_length, input_hidden_size)",
|
||||
"T1")
|
||||
.Input(
|
||||
1,
|
||||
"weight",
|
||||
"2D input tensor with shape (hidden_size, 3 * hidden_size)",
|
||||
"2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size",
|
||||
"T2")
|
||||
.Input(
|
||||
2,
|
||||
|
|
@ -418,18 +437,8 @@ and present state are optional. Present state could appear in output even when p
|
|||
.TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
// Type inference
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
|
||||
|
||||
// Shape inference
|
||||
// if the input shape doesn't exist, further shape inference is not possible
|
||||
if (!hasNInputShapes(ctx, 1)) {
|
||||
return;
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
return;
|
||||
constexpr int past_input_index = 8;
|
||||
AttentionTypeAndShapeInference(ctx, past_input_index);
|
||||
});
|
||||
|
||||
static const char* Longformer_Attention_doc = R"DOC(
|
||||
|
|
@ -456,7 +465,7 @@ Global attention flags have value 1 for the tokens attend globally and 0 otherwi
|
|||
.Input(4, "global_weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
|
||||
.Input(5, "global_bias", "1D input tensor with shape (3 * hidden_size)", "T")
|
||||
.Input(6, "global", "Global attention flags with shape (batch_size, sequence_length)", "G")
|
||||
.Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
|
||||
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeConstraint("G", {"tensor(int32)"}, "Constrain to integer types")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -2,29 +2,32 @@ import onnx
|
|||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
from enum import Enum
|
||||
from packaging import version
|
||||
|
||||
if version.parse(onnx.__version__) == version.parse('1.8.0'):
|
||||
opset_version = 13
|
||||
elif version.parse(onnx.__version__) == version.parse('1.6.0'):
|
||||
opset_version = 11
|
||||
else:
|
||||
raise RuntimeError("Please pip install onnx==1.8.0 or 1.6.0 before running this script")
|
||||
|
||||
def GenerateNodes(model_name, has_cast, suffix=''):
|
||||
nodes = [ # LayerNorm subgraph
|
||||
helper.make_node("Shape", ["input_ids" + suffix], ["shape1_out" + suffix], "shape1" + suffix),
|
||||
helper.make_node("Gather", ["shape1_out" + suffix, "indices_0"], ["gather0_out" + suffix], "gather0" + suffix),
|
||||
helper.make_node("Unsqueeze", ["gather0_out" + suffix], ["unsqueeze0_out" + suffix],
|
||||
"unsqueeze0" + suffix,
|
||||
axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather0_out" + suffix, "axes_0"], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix) if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["gather0_out" + suffix], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix, axes=[0]),
|
||||
helper.make_node("Shape", ["input_ids" + suffix], ["shape2_out" + suffix], "shape2" + suffix),
|
||||
helper.make_node("Gather", ["shape2_out" + suffix, "indices_1"], ["gather1_out" + suffix], "gather1" + suffix),
|
||||
helper.make_node("Unsqueeze", ["gather1_out" + suffix], ["unsqueeze1_out" + suffix],
|
||||
"unsqueeze1" + suffix,
|
||||
axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather1_out" + suffix, "axes_0"], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix) if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["gather1_out" + suffix], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix, axes=[0]),
|
||||
helper.make_node("Concat", ["unsqueeze0_out" + suffix, "unsqueeze1_out" + suffix], ["concat_out" + suffix],
|
||||
"concat" + suffix,
|
||||
axis=0),
|
||||
"concat" + suffix, axis=0),
|
||||
helper.make_node("Cast", ["gather1_out" + suffix], ["cast_out" + suffix], "cast" + suffix, to=7),
|
||||
helper.make_node("Range", ["start_0", "cast_out" + suffix if has_cast else "gather1_out" + suffix, "delta_1"],
|
||||
["range_out" + suffix], "range" + suffix),
|
||||
helper.make_node("Unsqueeze", ["range_out" + suffix], ["unsqueeze2_out" + suffix],
|
||||
"unsqueeze2" + suffix,
|
||||
axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["range_out" + suffix, "axes_0"], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix) if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["range_out" + suffix], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix, axes=[0]),
|
||||
helper.make_node("Expand", ["unsqueeze2_out" + suffix, "concat_out" + suffix], ["expand_out" + suffix],
|
||||
"expand" + suffix),
|
||||
helper.make_node("Gather", ["pos_embed", "expand_out" + suffix], ["pos_gather_out" + suffix],
|
||||
|
|
@ -43,10 +46,8 @@ def GenerateNodes(model_name, has_cast, suffix=''):
|
|||
axis=-1,
|
||||
epsion=0.000009999999747378752),
|
||||
helper.make_node("Cast", ["input_mask" + suffix], ["mask_cast_out" + suffix], "mask_cast" + suffix, to=6),
|
||||
helper.make_node("ReduceSum", ["mask_cast_out" + suffix], ["mask_index_out" + suffix],
|
||||
"mask_index" + suffix,
|
||||
axes=[1],
|
||||
keepdims=0),
|
||||
helper.make_node("ReduceSum", ["mask_cast_out" + suffix, "axes_1"], ["mask_index_out" + suffix], "mask_index" + suffix, keepdims=0) if opset_version == 13 \
|
||||
else helper.make_node("ReduceSum", ["mask_cast_out" + suffix], ["mask_index_out" + suffix], "mask_index" + suffix, axes=[1], keepdims=0),
|
||||
helper.make_node("Attention", ["layernorm_out" + suffix, "qkv_weights", "qkv_bias", "mask_index_out" + suffix],
|
||||
["att_out" + suffix],
|
||||
"att" + suffix,
|
||||
|
|
@ -63,7 +64,7 @@ def GenerateNodes(model_name, has_cast, suffix=''):
|
|||
|
||||
|
||||
def GenerateInitializers():
|
||||
# hidden_size=4, num_heads=2, max_seq_length=3
|
||||
# hidden_size=4, num_heads=2
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
|
||||
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
|
||||
|
|
@ -75,12 +76,14 @@ def GenerateInitializers():
|
|||
helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12),
|
||||
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12],
|
||||
[0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]),
|
||||
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
|
||||
]
|
||||
|
||||
return initializers
|
||||
|
|
@ -154,7 +157,9 @@ def GenerateModel5(model_name):
|
|||
axis=-1,
|
||||
epsion=0.000009999999747378752),
|
||||
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
|
||||
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
|
||||
helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \
|
||||
else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
|
||||
|
||||
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
|
||||
"att",
|
||||
domain="com.microsoft",
|
||||
|
|
@ -164,11 +169,7 @@ def GenerateModel5(model_name):
|
|||
helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2")
|
||||
]
|
||||
|
||||
qkv_weights = [
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
|
||||
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
|
||||
]
|
||||
qkv_weights = [1.0] * hidden_size * (3 * hidden_size)
|
||||
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
|
|
@ -185,6 +186,7 @@ def GenerateModel5(model_name):
|
|||
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
|
|
@ -208,16 +210,19 @@ def GenerateModel6(model_name):
|
|||
nodes = [ # LayerNorm subgraph
|
||||
helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
|
||||
helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"),
|
||||
helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
|
||||
helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"),
|
||||
helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"),
|
||||
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
|
||||
helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"),
|
||||
helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"),
|
||||
helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"),
|
||||
helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"], "range"),
|
||||
helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["range_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"),
|
||||
helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"),
|
||||
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"),
|
||||
|
|
@ -229,7 +234,8 @@ def GenerateModel6(model_name):
|
|||
axis=-1,
|
||||
epsion=0.000009999999747378752),
|
||||
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
|
||||
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
|
||||
helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \
|
||||
else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
|
||||
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
|
||||
"att",
|
||||
domain="com.microsoft",
|
||||
|
|
@ -251,15 +257,16 @@ def GenerateModel6(model_name):
|
|||
helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12),
|
||||
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12], [0.1] * 12),
|
||||
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]),
|
||||
helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]),
|
||||
helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]),
|
||||
helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]),
|
||||
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
|
|
@ -278,12 +285,9 @@ def GenerateModel6(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
def GenerateInitializers2(hidden_size):
|
||||
qkv_weights = [
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
|
||||
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
|
||||
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
|
||||
]
|
||||
qkv_weights = [1.0] * hidden_size * (3 * hidden_size)
|
||||
|
||||
initializers = [ # initializers
|
||||
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
|
|
@ -300,31 +304,32 @@ def GenerateInitializers2(hidden_size):
|
|||
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
|
||||
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
|
||||
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
|
||||
helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]),
|
||||
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
|
||||
]
|
||||
|
||||
return initializers
|
||||
|
||||
|
||||
def GenerateNodes2(attention_heads):
|
||||
nodes = [
|
||||
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0),
|
||||
|
||||
helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"),
|
||||
helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"),
|
||||
helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"),
|
||||
helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["range0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
|
||||
helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
|
||||
helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"),
|
||||
helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather", axis=0),
|
||||
|
||||
helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["add1_out"], "add1"),
|
||||
helper.make_node("LayerNormalization", ["add1_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"],
|
||||
"layernorm",
|
||||
axis=-1,
|
||||
epsion=0.000009999999747378752),
|
||||
|
||||
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
|
||||
|
||||
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
|
||||
helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \
|
||||
else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
|
||||
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
|
||||
"att",
|
||||
domain="com.microsoft",
|
||||
|
|
@ -336,6 +341,7 @@ def GenerateNodes2(attention_heads):
|
|||
|
||||
return nodes
|
||||
|
||||
|
||||
def GenerateModel7(model_name):
|
||||
batch_size = 2
|
||||
hidden_size = 4
|
||||
|
|
@ -361,6 +367,7 @@ def GenerateModel7(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
def GenerateModel8(model_name):
|
||||
batch_size = -1
|
||||
hidden_size = 4
|
||||
|
|
@ -395,6 +402,7 @@ def GenerateModel8(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
|
||||
def GenerateModel9(model_name):
|
||||
batch_size = -1
|
||||
hidden_size = 4
|
||||
|
|
@ -412,10 +420,13 @@ def GenerateModel9(model_name):
|
|||
helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"),
|
||||
helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"),
|
||||
helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"),
|
||||
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
|
||||
helper.make_node("Unsqueeze", ["gather2_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \
|
||||
else helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
|
||||
helper.make_node("Concat", ["unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0),
|
||||
helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], "constant_of_shape",
|
||||
helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'],
|
||||
"constant_of_shape",
|
||||
value=helper.make_tensor('mask_shape', TensorProto.FLOAT, [1], [1.0])),
|
||||
helper.make_node("Cast", ["constant_of_shape_out"], ["mask_cast_out"], "mask_cast", to=6),
|
||||
]
|
||||
|
|
@ -437,11 +448,21 @@ def GenerateModel9(model_name):
|
|||
model = helper.make_model(graph)
|
||||
onnx.save(model, model_name)
|
||||
|
||||
GenerateModel3('embed_layer_norm_format3.onnx', True)
|
||||
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
|
||||
GenerateModel5('embed_layer_norm_format5.onnx')
|
||||
GenerateModel6('embed_layer_norm_format6.onnx')
|
||||
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
|
||||
GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask
|
||||
GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask
|
||||
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')
|
||||
if opset_version == 11:
|
||||
GenerateModel3('embed_layer_norm_format3.onnx', True)
|
||||
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
|
||||
GenerateModel5('embed_layer_norm_format5.onnx')
|
||||
GenerateModel6('embed_layer_norm_format6.onnx')
|
||||
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
|
||||
GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask
|
||||
GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask
|
||||
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')
|
||||
else:
|
||||
GenerateModel3('embed_layer_norm_format3_opset13.onnx', True)
|
||||
GenerateModel3('embed_layer_norm_format3_no_cast_opset13.onnx', False)
|
||||
GenerateModel5('embed_layer_norm_format5_opset13.onnx')
|
||||
GenerateModel6('embed_layer_norm_format6_opset13.onnx')
|
||||
GenerateModel7('embed_layer_norm_format7_opset13.onnx')
|
||||
GenerateModel8('embed_layer_norm_format8_opset13.onnx')
|
||||
GenerateModel9('embed_layer_norm_format9_opset13.onnx')
|
||||
GenerateMultipleEmbedModel('embed_layer_norm_multiple_opset13.onnx')
|
||||
Binary file not shown.
Loading…
Reference in a new issue