Update Attention and QAttention to support pruned model (#6819)

* update Attention operator spec to support pruned model
* update Attention and QAttention cpu & cuda kernel
* Fix invalid embed layer norm fusion test models.
This commit is contained in:
Tianlei Wu 2021-02-27 09:50:16 -08:00 committed by GitHub
parent cb8d8464bc
commit 2d6e10ba00
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 247 additions and 188 deletions

View file

@ -54,14 +54,17 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const Tensor*& mask_index,
const Tensor* past) const {
// Input shapes:
// input : (batch_size, sequence_length, hidden_size)
// weights : (hidden_size, 3 * hidden_size)
// input : (batch_size, sequence_length, input_hidden_size)
// weights : (input_hidden_size, 3 * hidden_size)
// bias : (3 * hidden_size)
// mask_index : nullptr, (batch_size), (2 * batch_size),
// or (batch_size, 1), (1, 1)
// or (batch_size, past_sequence_length + sequence_length)
// or (batch_size, sequence_length, past_sequence_length + sequence_length)
// past : (2, batch_size, num_heads, past_sequence_length, head_size)
//
// Where hidden_size = num_heads * head_size.
// When a model is pruned (like some attention heads are removed), hidden_size < input_hidden_size.
const auto& dims = input_shape.GetDims();
if (dims.size() != 3) {
@ -70,11 +73,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
}
int batch_size = static_cast<int>(dims[0]);
int sequence_length = static_cast<int>(dims[1]);
int hidden_size = static_cast<int>(dims[2]);
if (hidden_size % num_heads_ != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 0 dimension 2 should be divisiable by value of the num_heads attribute.");
}
const auto& weights_dims = weights_shape.GetDims();
if (weights_dims.size() != 2) {
@ -85,8 +83,15 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 1 dimension 0 should have same length as dimension 2 of input 0");
}
if (weights_dims[1] != 3 * weights_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' dimension 1 should be 3 times of dimension 0");
int hidden_size = static_cast<int>(weights_dims[1]) / 3;
if (3 * hidden_size != static_cast<int>(weights_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 1 dimension 1 should be 3 times of hidden dimension");
}
if (hidden_size % num_heads_ != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
}
const auto& bias_dims = bias_shape.GetDims();
@ -125,7 +130,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const auto& mask_dims = mask_index->Shape().GetDims();
if (mask_dims.size() == 1) {
if (static_cast<int>(mask_dims[0]) != batch_size && static_cast<int>(mask_dims[0]) != 2 * batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' dimension 0 shall have length of batch_size or 2 * batch_size");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
}
} else if (mask_dims.size() == 2) {
if (static_cast<int>(mask_dims[0]) != batch_size || static_cast<int>(mask_dims[1]) != past_sequence_length + sequence_length) {
@ -134,12 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// Mask will have same value after propogation, which has same effect as no mask.
mask_index = nullptr;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)");
}
}
} else if (mask_dims.size() == 3) {
if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' of 3d shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ",
@ -193,8 +198,9 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pack
return Status::OK();
}
const size_t hidden_size = static_cast<size_t>(weights_dims[0]);
const size_t input_hidden_size = static_cast<size_t>(weights_dims[0]);
const size_t hidden_size_x3 = static_cast<size_t>(weights_dims[1]);
const size_t hidden_size = hidden_size_x3 / 3;
const size_t head_size = hidden_size / num_heads_;
// Bail out if the weights shape has an expected shape.
@ -204,7 +210,7 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pack
const auto* weights_data = weights.Data<T>();
packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size);
packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size);
if (packed_weights_size_ == 0) {
return Status::OK();
}
@ -215,7 +221,7 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pack
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
for (size_t i = 0; i < loop_len; i++) {
MlasGemmPackB(CblasNoTrans, head_size, hidden_size, weights_data, hidden_size_x3, packed_weights_data);
MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, hidden_size_x3, packed_weights_data);
packed_weights_data += packed_weights_size_;
weights_data += head_size;
}
@ -232,8 +238,9 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(4);
const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape());
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights ? weights->Shape() : weight_shape_,
weights_shape,
bias->Shape(),
mask_index,
past));
@ -241,10 +248,17 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
const auto& shape = input->Shape().GetDims();
const int batch_size = static_cast<int>(shape[0]);
const int sequence_length = static_cast<int>(shape[1]);
const int hidden_size = static_cast<int>(shape[2]);
const int input_hidden_size = static_cast<int>(shape[2]);
const auto& weights_dims = weights_shape.GetDims();
const int hidden_size = static_cast<int>(weights_dims[1]) / 3;
const int head_size = hidden_size / num_heads_;
Tensor* output = context->Output(0, shape);
std::vector<int64_t> output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
Tensor* output = context->Output(0, output_shape);
constexpr size_t element_size = sizeof(T);
@ -253,7 +267,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
auto* tp = context->GetOperatorThreadPool();
// Compute Q, K, V
// gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH)
// gemm_data(BS, 3NH) = input(BS, D) x weights(D, 3NH) + bias(3NH)
// D (input_hidden_size) is hidden dimension of input, where D could be larger than hidden_size (NH) when model is pruned.
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
@ -269,14 +284,14 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
const auto* bias_data = bias->template Data<T>();
const double cost =
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(hidden_size);
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(input_hidden_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>((i / 3) / num_heads_);
const int head_index = static_cast<int>((i / 3) % num_heads_);
const int qkv_index = static_cast<int>(i % 3);
int input_offset = batch_index * sequence_length * hidden_size;
int input_offset = batch_index * sequence_length * input_hidden_size;
int weights_offset = qkv_index * hidden_size + head_index * head_size;
T* qkv_dest = QKV[qkv_index];
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
@ -290,8 +305,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
}
// original transposed iteration
// A: input (BxSxNxH) (B.)S x NH S x NH
// B: weights (NxHx3xNxH) NH x (3.N.)H NH x H
// A: input (BxSxD) (B.)S x D S x D
// B: weights (Dx3xNxH) D x (3.N.)H D x H
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
if (packed_weights_) {
const auto* packed_weight =
@ -300,30 +315,31 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
CblasNoTrans, // TransA = no
sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
input_hidden_size, // K = D
1.0f, // alpha
input_data + input_offset, // A
hidden_size, // lda = NH
input_hidden_size, // lda = D
packed_weight, // B
1.0f, // beta
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr); // use single-thread
} else {
math::GemmEx<float, ThreadPool>(CblasNoTrans, // TransA = no
CblasNoTrans, // TransB = no
sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
1.0f, // alpha
input_data + input_offset, // A
hidden_size, // lda = NH
weights_data + weights_offset, // B
3 * hidden_size, // ldb = 3NH
1.0f, // beta
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr // use single-thread
math::GemmEx<float, ThreadPool>(
CblasNoTrans, // TransA = no
CblasNoTrans, // TransB = no
sequence_length, // M = S
head_size, // N = H
input_hidden_size, // K = D
1.0f, // alpha
input_data + input_offset, // A
input_hidden_size, // lda = D
weights_data + weights_offset, // B
3 * hidden_size, // ldb = 3NH
1.0f, // beta
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr // use single-thread
);
}
}

View file

@ -66,8 +66,9 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
return Status::OK();
}
const size_t hidden_size = static_cast<size_t>(weights_dims[0]);
const size_t input_hidden_size = static_cast<size_t>(weights_dims[0]);
const size_t hidden_size_x3 = static_cast<size_t>(weights_dims[1]);
const size_t hidden_size = hidden_size_x3 / 3;
const size_t head_size = hidden_size / num_heads_;
// Bail out if the weights shape has an expected shape.
@ -78,7 +79,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
const auto* weights_data = static_cast<const uint8_t*>(weights.DataRaw());
weights_is_signed_ = weights.IsDataType<int8_t>();
packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size, weights_is_signed_);
packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size, weights_is_signed_);
if (packed_weights_size_ == 0) {
return Status::OK();
}
@ -89,7 +90,7 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
for (size_t i = 0; i < loop_len; i++) {
MlasGemmPackB(head_size, hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data);
MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data);
packed_weights_data += packed_weights_size_;
weights_data += head_size;
}
@ -102,8 +103,8 @@ Status QAttention<T>::PrePack(const Tensor& weights, int input_idx, bool& is_pac
template <typename T>
Status QAttention<T>::Compute(OpKernelContext* context) const {
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 0 - input : (batch_size, sequence_length, input_hidden_size)
// Input 1 - weights : (input_hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - input_scale : scalar
// Input 4 - weight_scale : scalar
@ -124,8 +125,9 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
const Tensor* w_zp_tensor = context->Input<Tensor>(7);
const Tensor* past_tensor = context->Input<Tensor>(8);
const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape());
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(),
packed_weights_ ? weight_shape_ : weights->Shape(),
weights_shape,
bias->Shape(),
mask_index,
past_tensor));
@ -157,10 +159,17 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
const auto& shape = input->Shape();
const int batch_size = static_cast<int>(shape[0]);
const int sequence_length = static_cast<int>(shape[1]);
const int hidden_size = static_cast<int>(shape[2]);
const int input_hidden_size = static_cast<int>(shape[2]);
const auto hidden_size_x3 = weights_shape.GetDims()[1];
const int hidden_size = static_cast<int>(hidden_size_x3) / 3;
const int head_size = hidden_size / num_heads_;
Tensor* output = context->Output(0, shape);
std::vector<int64_t> output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
Tensor* output = context->Output(0, output_shape);
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
@ -168,7 +177,8 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
constexpr size_t element_size = sizeof(T);
auto* tp = context->GetOperatorThreadPool();
// STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, NH) x weights(NH, 3NH)) + bias(3NH)
// STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, D) x weights(D, 3NH)) + bias(3NH)
// D is hidden dimension of input, where input_hidden_size (D) could be larger than hidden_size (NH) when model is pruned.
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
@ -186,21 +196,21 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
const bool weights_is_signed = packed_weights_ ? weights_is_signed_ : weights->IsDataType<int8_t>();
const double cost =
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(hidden_size);
static_cast<double>(sequence_length) * static_cast<double>(head_size) * static_cast<double>(input_hidden_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>((i / 3) / num_heads_);
const int head_index = static_cast<int>((i / 3) % num_heads_);
const int qkv_index = static_cast<int>(i % 3);
int input_offset = batch_index * sequence_length * hidden_size;
int input_offset = batch_index * sequence_length * input_hidden_size;
int weights_offset = qkv_index * hidden_size + head_index * head_size;
float* qkv_dest = QKV[qkv_index];
int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size);
// original transposed iteration
// A: input (BxSxNxH) (B.)S x NH S x NH
// B: weights (NxHx3xNxH) NH x (3.N.)H NH x H
// A: input (BxSxD) (B.)S x D S x D
// B: weights (Dx3xNxH) D x (3.N.)H D x H
// C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H
MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset,
@ -215,9 +225,9 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
MlasGemm(
sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
input_hidden_size, // K = D
input_data + input_offset, // A
hidden_size, // lda = NH
input_hidden_size, // lda = D
input_zero_point, // input zero point
packed_weight, // B
weight_zero_point, // weight zero point
@ -233,9 +243,9 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {
MlasGemm(
sequence_length, // M = S
head_size, // N = H
hidden_size, // K = NH
input_hidden_size, // K = D
input_data + input_offset, // A
hidden_size, // lda = NH
input_hidden_size, // lda = D
input_zero_point, // input zero point
weights_data + weights_offset, // B
3 * hidden_size, // ldb = 3NH

View file

@ -41,16 +41,23 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* past = context->Input<Tensor>(4);
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past));
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Output 0 - output : (batch_size, sequence_length, hidden_size)
// input shape (batch_size, sequence_length, input_hidden_size)
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int hidden_size = static_cast<int>(shape[2]);
int head_size = hidden_size / num_heads_;
int input_hidden_size = static_cast<int>(shape[2]);
Tensor* output = context->Output(0, shape);
// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
int head_size = hidden_size / num_heads_;
std::vector<int64_t> output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
Tensor* output = context->Output(0, output_shape);
int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
@ -61,7 +68,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = hidden_size;
int k = input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
typedef typename ToCudaType<T>::MappedType CudaT;

View file

@ -49,17 +49,6 @@ Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
const Tensor* i_zp_tensor,
const Tensor* w_zp_tensor,
const Tensor* past_tensor) const {
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - input_scale : scalar
// Input 4 - weight_scale : scalar
// Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
// Input 6 - input_zero_point : scalar
// Input 7 - weight_zero_point : scalar
// Output : (batch_size, sequence_length, hidden_size)
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor));
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
@ -89,12 +78,12 @@ Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
template <typename T>
Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 0 - input : (batch_size, sequence_length, input_hidden_size)
// Input 1 - weights : (input_hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - input_scale : scalar
// Input 4 - weight_scale : scalar
// Input 5 - mask_index : (batch_size)
// Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
// Input 6 - input_zero_point : scalar
// Input 7 - weight_zero_point : scalar
// Input 8 - past : (2, batch_size, num_heads, past_sequence_length, head_size)
@ -124,10 +113,17 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int hidden_size = static_cast<int>(shape[2]);
int head_size = hidden_size / num_heads_;
int input_hidden_size = static_cast<int>(shape[2]);
Tensor* output = context->Output(0, shape);
const auto& bias_shape = bias->Shape();
const int hidden_size = static_cast<int>(bias_shape.GetDims()[0]) / 3;
const int head_size = hidden_size / num_heads_;
std::vector<int64_t> output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
Tensor* output = context->Output(0, output_shape);
cublasHandle_t cublas = CublasHandle();
const size_t element_size = sizeof(T);
@ -135,7 +131,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = hidden_size;
int k = input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(batch_size * sequence_length * 3 * hidden_size);

View file

@ -275,6 +275,59 @@ void FusedMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
updateOutputShape(ctx, 0, resultShape);
}
void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) {
// Type inference
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
if (ctx.getNumOutputs() > 1) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 1);
}
// Shape inference
if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) {
auto& input_shape = getInputShape(ctx, 0);
auto& input_dims = input_shape.dim();
if (input_dims.size() != 3) {
fail_shape_inference("Inputs 0 shall be 3 dimensions");
}
auto& bias_shape = getInputShape(ctx, 2);
auto& bias_dims = bias_shape.dim();
if (bias_dims.size() != 1 || bias_shape.dim(0).dim_value() % 3 != 0) {
fail_shape_inference("Invalid bias shape");
}
ONNX_NAMESPACE::TensorShapeProto output_shape;
for (auto& dim : input_dims) {
*output_shape.add_dim() = dim;
}
output_shape.mutable_dim(2)->set_dim_value(bias_shape.dim(0).dim_value() / 3);
updateOutputShape(ctx, 0, output_shape);
if (ctx.getNumOutputs() > 1) {
if (hasInputShape(ctx, past_input_index)) {
auto& past_shape = getInputShape(ctx, past_input_index);
auto& past_dims = past_shape.dim();
if (past_dims.size() != 5) {
fail_shape_inference("Inputs 4 shall be 5 dimensions");
}
if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value();
ONNX_NAMESPACE::TensorShapeProto present_shape;
for (auto& dim : past_dims) {
*present_shape.add_dim() = dim;
}
present_shape.mutable_dim(3)->set_dim_value(all_sequence_length);
updateOutputShape(ctx, 1, present_shape);
}
}
}
}
}
void RegisterBertSchemas() {
static const char* Attention_ver1_doc = R"DOC(
Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
@ -296,8 +349,8 @@ and present state are optional. Present state could appear in output even when p
"Whether every token can only attend to previous tokens. Default value is 0.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
.Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T")
.Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T")
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
.Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
.Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
@ -306,42 +359,8 @@ and present state are optional. Present state could appear in output even when p
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (ctx.getNumOutputs() > 1) {
propagateElemTypeFromInputToOutput(ctx, 0, 1);
}
if (hasInputShape(ctx, 0)) {
propagateShapeFromInputToOutput(ctx, 0, 0);
if (ctx.getNumOutputs() > 1) {
auto& input_shape = getInputShape(ctx, 0);
auto& input_dims = input_shape.dim();
if (input_dims.size() != 3) {
fail_shape_inference("Inputs 0 shall be 3 dimensions");
}
if (hasInputShape(ctx, 4)) {
auto& past_shape = getInputShape(ctx, 4);
auto& past_dims = past_shape.dim();
if (past_dims.size() != 5) {
fail_shape_inference("Inputs 4 shall be 5 dimensions");
}
if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) {
auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value();
ONNX_NAMESPACE::TensorShapeProto present_shape;
for (auto& dim : past_dims) {
*present_shape.add_dim() = dim;
}
present_shape.mutable_dim(3)->set_dim_value(all_sequence_length);
updateOutputShape(ctx, 1, present_shape);
}
}
}
}
constexpr int past_input_index = 4;
AttentionTypeAndShapeInference(ctx, past_input_index);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(QAttention)
@ -356,12 +375,12 @@ and present state are optional. Present state could appear in output even when p
.Input(
0,
"input",
"3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size",
"3D input tensor with shape (batch_size, sequence_length, input_hidden_size)",
"T1")
.Input(
1,
"weight",
"2D input tensor with shape (hidden_size, 3 * hidden_size)",
"2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size",
"T2")
.Input(
2,
@ -418,18 +437,8 @@ and present state are optional. Present state could appear in output even when p
.TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
// Shape inference
// if the input shape doesn't exist, further shape inference is not possible
if (!hasNInputShapes(ctx, 1)) {
return;
}
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
return;
constexpr int past_input_index = 8;
AttentionTypeAndShapeInference(ctx, past_input_index);
});
static const char* Longformer_Attention_doc = R"DOC(
@ -456,7 +465,7 @@ Global attention flags have value 1 for the tokens attend globally and 0 otherwi
.Input(4, "global_weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
.Input(5, "global_bias", "1D input tensor with shape (3 * hidden_size)", "T")
.Input(6, "global", "Global attention flags with shape (batch_size, sequence_length)", "G")
.Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("G", {"tensor(int32)"}, "Constrain to integer types")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);

View file

@ -2,29 +2,32 @@ import onnx
from onnx import helper
from onnx import TensorProto
from enum import Enum
from packaging import version
if version.parse(onnx.__version__) == version.parse('1.8.0'):
opset_version = 13
elif version.parse(onnx.__version__) == version.parse('1.6.0'):
opset_version = 11
else:
raise RuntimeError("Please pip install onnx==1.8.0 or 1.6.0 before running this script")
def GenerateNodes(model_name, has_cast, suffix=''):
nodes = [ # LayerNorm subgraph
helper.make_node("Shape", ["input_ids" + suffix], ["shape1_out" + suffix], "shape1" + suffix),
helper.make_node("Gather", ["shape1_out" + suffix, "indices_0"], ["gather0_out" + suffix], "gather0" + suffix),
helper.make_node("Unsqueeze", ["gather0_out" + suffix], ["unsqueeze0_out" + suffix],
"unsqueeze0" + suffix,
axes=[0]),
helper.make_node("Unsqueeze", ["gather0_out" + suffix, "axes_0"], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix) if opset_version == 13 \
else helper.make_node("Unsqueeze", ["gather0_out" + suffix], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix, axes=[0]),
helper.make_node("Shape", ["input_ids" + suffix], ["shape2_out" + suffix], "shape2" + suffix),
helper.make_node("Gather", ["shape2_out" + suffix, "indices_1"], ["gather1_out" + suffix], "gather1" + suffix),
helper.make_node("Unsqueeze", ["gather1_out" + suffix], ["unsqueeze1_out" + suffix],
"unsqueeze1" + suffix,
axes=[0]),
helper.make_node("Unsqueeze", ["gather1_out" + suffix, "axes_0"], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix) if opset_version == 13 \
else helper.make_node("Unsqueeze", ["gather1_out" + suffix], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix, axes=[0]),
helper.make_node("Concat", ["unsqueeze0_out" + suffix, "unsqueeze1_out" + suffix], ["concat_out" + suffix],
"concat" + suffix,
axis=0),
"concat" + suffix, axis=0),
helper.make_node("Cast", ["gather1_out" + suffix], ["cast_out" + suffix], "cast" + suffix, to=7),
helper.make_node("Range", ["start_0", "cast_out" + suffix if has_cast else "gather1_out" + suffix, "delta_1"],
["range_out" + suffix], "range" + suffix),
helper.make_node("Unsqueeze", ["range_out" + suffix], ["unsqueeze2_out" + suffix],
"unsqueeze2" + suffix,
axes=[0]),
helper.make_node("Unsqueeze", ["range_out" + suffix, "axes_0"], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix) if opset_version == 13 \
else helper.make_node("Unsqueeze", ["range_out" + suffix], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix, axes=[0]),
helper.make_node("Expand", ["unsqueeze2_out" + suffix, "concat_out" + suffix], ["expand_out" + suffix],
"expand" + suffix),
helper.make_node("Gather", ["pos_embed", "expand_out" + suffix], ["pos_gather_out" + suffix],
@ -43,10 +46,8 @@ def GenerateNodes(model_name, has_cast, suffix=''):
axis=-1,
epsion=0.000009999999747378752),
helper.make_node("Cast", ["input_mask" + suffix], ["mask_cast_out" + suffix], "mask_cast" + suffix, to=6),
helper.make_node("ReduceSum", ["mask_cast_out" + suffix], ["mask_index_out" + suffix],
"mask_index" + suffix,
axes=[1],
keepdims=0),
helper.make_node("ReduceSum", ["mask_cast_out" + suffix, "axes_1"], ["mask_index_out" + suffix], "mask_index" + suffix, keepdims=0) if opset_version == 13 \
else helper.make_node("ReduceSum", ["mask_cast_out" + suffix], ["mask_index_out" + suffix], "mask_index" + suffix, axes=[1], keepdims=0),
helper.make_node("Attention", ["layernorm_out" + suffix, "qkv_weights", "qkv_bias", "mask_index_out" + suffix],
["att_out" + suffix],
"att" + suffix,
@ -63,7 +64,7 @@ def GenerateNodes(model_name, has_cast, suffix=''):
def GenerateInitializers():
# hidden_size=4, num_heads=2, max_seq_length=3
# hidden_size=4, num_heads=2
initializers = [ # initializers
helper.make_tensor('indices_0', TensorProto.INT64, [], [0]),
helper.make_tensor('indices_1', TensorProto.INT64, [], [1]),
@ -75,12 +76,14 @@ def GenerateInitializers():
helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12),
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12],
[0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]),
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
]
return initializers
@ -154,7 +157,9 @@ def GenerateModel5(model_name):
axis=-1,
epsion=0.000009999999747378752),
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \
else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
"att",
domain="com.microsoft",
@ -164,11 +169,7 @@ def GenerateModel5(model_name):
helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2")
]
qkv_weights = [
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
]
qkv_weights = [1.0] * hidden_size * (3 * hidden_size)
initializers = [ # initializers
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
@ -185,6 +186,7 @@ def GenerateModel5(model_name):
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
]
graph = helper.make_graph(
@ -208,16 +210,19 @@ def GenerateModel6(model_name):
nodes = [ # LayerNorm subgraph
helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"),
helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Unsqueeze", ["gather0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \
else helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"),
helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"),
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \
else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0),
helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"),
helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"),
helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"),
helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"], "range"),
helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
helper.make_node("Unsqueeze", ["range_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \
else helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"),
helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"),
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"),
@ -229,7 +234,8 @@ def GenerateModel6(model_name):
axis=-1,
epsion=0.000009999999747378752),
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \
else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
"att",
domain="com.microsoft",
@ -251,15 +257,16 @@ def GenerateModel6(model_name):
helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12),
helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12], [0.1] * 12),
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]),
helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]),
helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]),
helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]),
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
]
graph = helper.make_graph(
@ -278,12 +285,9 @@ def GenerateModel6(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
def GenerateInitializers2(hidden_size):
qkv_weights = [
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0,
3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0,
1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0
]
qkv_weights = [1.0] * hidden_size * (3 * hidden_size)
initializers = [ # initializers
helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
@ -300,31 +304,32 @@ def GenerateInitializers2(hidden_size):
helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size],
[1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]),
helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]),
helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]),
helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]),
]
return initializers
def GenerateNodes2(attention_heads):
nodes = [
helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0),
helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"),
helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"),
helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"),
helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Unsqueeze", ["range0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \
else helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]),
helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"),
helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"),
helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather", axis=0),
helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["add1_out"], "add1"),
helper.make_node("LayerNormalization", ["add1_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"],
"layernorm",
axis=-1,
epsion=0.000009999999747378752),
helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6),
helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \
else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0),
helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"],
"att",
domain="com.microsoft",
@ -336,6 +341,7 @@ def GenerateNodes2(attention_heads):
return nodes
def GenerateModel7(model_name):
batch_size = 2
hidden_size = 4
@ -361,6 +367,7 @@ def GenerateModel7(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
def GenerateModel8(model_name):
batch_size = -1
hidden_size = 4
@ -395,6 +402,7 @@ def GenerateModel8(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
def GenerateModel9(model_name):
batch_size = -1
hidden_size = 4
@ -412,10 +420,13 @@ def GenerateModel9(model_name):
helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"),
helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"),
helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"),
helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \
else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]),
helper.make_node("Unsqueeze", ["gather2_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \
else helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]),
helper.make_node("Concat", ["unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0),
helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], "constant_of_shape",
helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'],
"constant_of_shape",
value=helper.make_tensor('mask_shape', TensorProto.FLOAT, [1], [1.0])),
helper.make_node("Cast", ["constant_of_shape_out"], ["mask_cast_out"], "mask_cast", to=6),
]
@ -437,11 +448,21 @@ def GenerateModel9(model_name):
model = helper.make_model(graph)
onnx.save(model, model_name)
GenerateModel3('embed_layer_norm_format3.onnx', True)
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
GenerateModel5('embed_layer_norm_format5.onnx')
GenerateModel6('embed_layer_norm_format6.onnx')
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask
GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')
if opset_version == 11:
GenerateModel3('embed_layer_norm_format3.onnx', True)
GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False)
GenerateModel5('embed_layer_norm_format5.onnx')
GenerateModel6('embed_layer_norm_format6.onnx')
GenerateModel7('embed_layer_norm_format7.onnx') #distilbert
GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask
GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask
GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx')
else:
GenerateModel3('embed_layer_norm_format3_opset13.onnx', True)
GenerateModel3('embed_layer_norm_format3_no_cast_opset13.onnx', False)
GenerateModel5('embed_layer_norm_format5_opset13.onnx')
GenerateModel6('embed_layer_norm_format6_opset13.onnx')
GenerateModel7('embed_layer_norm_format7_opset13.onnx')
GenerateModel8('embed_layer_norm_format8_opset13.onnx')
GenerateModel9('embed_layer_norm_format9_opset13.onnx')
GenerateMultipleEmbedModel('embed_layer_norm_multiple_opset13.onnx')