diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index be0aae79eb..65352fe2a1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -54,14 +54,17 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const Tensor*& mask_index, const Tensor* past) const { // Input shapes: - // input : (batch_size, sequence_length, hidden_size) - // weights : (hidden_size, 3 * hidden_size) + // input : (batch_size, sequence_length, input_hidden_size) + // weights : (input_hidden_size, 3 * hidden_size) // bias : (3 * hidden_size) // mask_index : nullptr, (batch_size), (2 * batch_size), // or (batch_size, 1), (1, 1) // or (batch_size, past_sequence_length + sequence_length) // or (batch_size, sequence_length, past_sequence_length + sequence_length) // past : (2, batch_size, num_heads, past_sequence_length, head_size) + // + // Where hidden_size = num_heads * head_size. + // When a model is pruned (like some attention heads are removed), hidden_size < input_hidden_size. const auto& dims = input_shape.GetDims(); if (dims.size() != 3) { @@ -70,11 +73,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } int batch_size = static_cast(dims[0]); int sequence_length = static_cast(dims[1]); - int hidden_size = static_cast(dims[2]); - if (hidden_size % num_heads_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 0 dimension 2 should be divisiable by value of the num_heads attribute."); - } const auto& weights_dims = weights_shape.GetDims(); if (weights_dims.size() != 2) { @@ -85,8 +83,15 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 1 dimension 0 should have same length as dimension 2 of input 0"); } - if (weights_dims[1] != 3 * weights_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' dimension 1 should be 3 times of dimension 0"); + + int hidden_size = static_cast(weights_dims[1]) / 3; + if (3 * hidden_size != static_cast(weights_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 dimension 1 should be 3 times of hidden dimension"); + } + + if (hidden_size % num_heads_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads."); } const auto& bias_dims = bias_shape.GetDims(); @@ -125,7 +130,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const auto& mask_dims = mask_index->Shape().GetDims(); if (mask_dims.size() == 1) { if (static_cast(mask_dims[0]) != batch_size && static_cast(mask_dims[0]) != 2 * batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' dimension 0 shall have length of batch_size or 2 * batch_size"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size"); } } else if (mask_dims.size() == 2) { if (static_cast(mask_dims[0]) != batch_size || static_cast(mask_dims[1]) != past_sequence_length + sequence_length) { @@ -134,12 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // Mask will have same value after propogation, which has same effect as no mask. mask_index = nullptr; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)"); } } } else if (mask_dims.size() == 3) { if (static_cast(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast(mask_dims[2]) != past_sequence_length + sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' of 3d shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)"); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ", @@ -193,8 +198,9 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, bool& is_pack return Status::OK(); } - const size_t hidden_size = static_cast(weights_dims[0]); + const size_t input_hidden_size = static_cast(weights_dims[0]); const size_t hidden_size_x3 = static_cast(weights_dims[1]); + const size_t hidden_size = hidden_size_x3 / 3; const size_t head_size = hidden_size / num_heads_; // Bail out if the weights shape has an expected shape. @@ -204,7 +210,7 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, bool& is_pack const auto* weights_data = weights.Data(); - packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size); + packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size); if (packed_weights_size_ == 0) { return Status::OK(); } @@ -215,7 +221,7 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, bool& is_pack packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(CblasNoTrans, head_size, hidden_size, weights_data, hidden_size_x3, packed_weights_data); + MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, hidden_size_x3, packed_weights_data); packed_weights_data += packed_weights_size_; weights_data += head_size; } @@ -232,8 +238,9 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); + const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape()); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), - weights ? weights->Shape() : weight_shape_, + weights_shape, bias->Shape(), mask_index, past)); @@ -241,10 +248,17 @@ Status Attention::Compute(OpKernelContext* context) const { const auto& shape = input->Shape().GetDims(); const int batch_size = static_cast(shape[0]); const int sequence_length = static_cast(shape[1]); - const int hidden_size = static_cast(shape[2]); + const int input_hidden_size = static_cast(shape[2]); + + const auto& weights_dims = weights_shape.GetDims(); + const int hidden_size = static_cast(weights_dims[1]) / 3; const int head_size = hidden_size / num_heads_; - Tensor* output = context->Output(0, shape); + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); constexpr size_t element_size = sizeof(T); @@ -253,7 +267,8 @@ Status Attention::Compute(OpKernelContext* context) const { auto* tp = context->GetOperatorThreadPool(); // Compute Q, K, V - // gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH) + // gemm_data(BS, 3NH) = input(BS, D) x weights(D, 3NH) + bias(3NH) + // D (input_hidden_size) is hidden dimension of input, where D could be larger than hidden_size (NH) when model is pruned. auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator)); @@ -269,14 +284,14 @@ Status Attention::Compute(OpKernelContext* context) const { const auto* bias_data = bias->template Data(); const double cost = - static_cast(sequence_length) * static_cast(head_size) * static_cast(hidden_size); + static_cast(sequence_length) * static_cast(head_size) * static_cast(input_hidden_size); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast((i / 3) / num_heads_); const int head_index = static_cast((i / 3) % num_heads_); const int qkv_index = static_cast(i % 3); - int input_offset = batch_index * sequence_length * hidden_size; + int input_offset = batch_index * sequence_length * input_hidden_size; int weights_offset = qkv_index * hidden_size + head_index * head_size; T* qkv_dest = QKV[qkv_index]; int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); @@ -290,8 +305,8 @@ Status Attention::Compute(OpKernelContext* context) const { } // original transposed iteration - // A: input (BxSxNxH) (B.)S x NH S x NH - // B: weights (NxHx3xNxH) NH x (3.N.)H NH x H + // A: input (BxSxD) (B.)S x D S x D + // B: weights (Dx3xNxH) D x (3.N.)H D x H // C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H if (packed_weights_) { const auto* packed_weight = @@ -300,30 +315,31 @@ Status Attention::Compute(OpKernelContext* context) const { CblasNoTrans, // TransA = no sequence_length, // M = S head_size, // N = H - hidden_size, // K = NH + input_hidden_size, // K = D 1.0f, // alpha input_data + input_offset, // A - hidden_size, // lda = NH + input_hidden_size, // lda = D packed_weight, // B 1.0f, // beta qkv_dest + qkv_offset, // C head_size, // ldc nullptr); // use single-thread } else { - math::GemmEx(CblasNoTrans, // TransA = no - CblasNoTrans, // TransB = no - sequence_length, // M = S - head_size, // N = H - hidden_size, // K = NH - 1.0f, // alpha - input_data + input_offset, // A - hidden_size, // lda = NH - weights_data + weights_offset, // B - 3 * hidden_size, // ldb = 3NH - 1.0f, // beta - qkv_dest + qkv_offset, // C - head_size, // ldc - nullptr // use single-thread + math::GemmEx( + CblasNoTrans, // TransA = no + CblasNoTrans, // TransB = no + sequence_length, // M = S + head_size, // N = H + input_hidden_size, // K = D + 1.0f, // alpha + input_data + input_offset, // A + input_hidden_size, // lda = D + weights_data + weights_offset, // B + 3 * hidden_size, // ldb = 3NH + 1.0f, // beta + qkv_dest + qkv_offset, // C + head_size, // ldc + nullptr // use single-thread ); } } diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 2c2ef835aa..e871a69c86 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -66,8 +66,9 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac return Status::OK(); } - const size_t hidden_size = static_cast(weights_dims[0]); + const size_t input_hidden_size = static_cast(weights_dims[0]); const size_t hidden_size_x3 = static_cast(weights_dims[1]); + const size_t hidden_size = hidden_size_x3 / 3; const size_t head_size = hidden_size / num_heads_; // Bail out if the weights shape has an expected shape. @@ -78,7 +79,7 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac const auto* weights_data = static_cast(weights.DataRaw()); weights_is_signed_ = weights.IsDataType(); - packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size, weights_is_signed_); + packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size, weights_is_signed_); if (packed_weights_size_ == 0) { return Status::OK(); } @@ -89,7 +90,7 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(head_size, hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data); + MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data); packed_weights_data += packed_weights_size_; weights_data += head_size; } @@ -102,8 +103,8 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac template Status QAttention::Compute(OpKernelContext* context) const { // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Input 1 - weights : (hidden_size, 3 * hidden_size) + // Input 0 - input : (batch_size, sequence_length, input_hidden_size) + // Input 1 - weights : (input_hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) // Input 3 - input_scale : scalar // Input 4 - weight_scale : scalar @@ -124,8 +125,9 @@ Status QAttention::Compute(OpKernelContext* context) const { const Tensor* w_zp_tensor = context->Input(7); const Tensor* past_tensor = context->Input(8); + const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape()); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), - packed_weights_ ? weight_shape_ : weights->Shape(), + weights_shape, bias->Shape(), mask_index, past_tensor)); @@ -157,10 +159,17 @@ Status QAttention::Compute(OpKernelContext* context) const { const auto& shape = input->Shape(); const int batch_size = static_cast(shape[0]); const int sequence_length = static_cast(shape[1]); - const int hidden_size = static_cast(shape[2]); + const int input_hidden_size = static_cast(shape[2]); + + const auto hidden_size_x3 = weights_shape.GetDims()[1]; + const int hidden_size = static_cast(hidden_size_x3) / 3; const int head_size = hidden_size / num_heads_; - Tensor* output = context->Output(0, shape); + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -168,7 +177,8 @@ Status QAttention::Compute(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); auto* tp = context->GetOperatorThreadPool(); - // STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, NH) x weights(NH, 3NH)) + bias(3NH) + // STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, D) x weights(D, 3NH)) + bias(3NH) + // D is hidden dimension of input, where input_hidden_size (D) could be larger than hidden_size (NH) when model is pruned. auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator)); @@ -186,21 +196,21 @@ Status QAttention::Compute(OpKernelContext* context) const { const bool weights_is_signed = packed_weights_ ? weights_is_signed_ : weights->IsDataType(); const double cost = - static_cast(sequence_length) * static_cast(head_size) * static_cast(hidden_size); + static_cast(sequence_length) * static_cast(head_size) * static_cast(input_hidden_size); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast((i / 3) / num_heads_); const int head_index = static_cast((i / 3) % num_heads_); const int qkv_index = static_cast(i % 3); - int input_offset = batch_index * sequence_length * hidden_size; + int input_offset = batch_index * sequence_length * input_hidden_size; int weights_offset = qkv_index * hidden_size + head_index * head_size; float* qkv_dest = QKV[qkv_index]; int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); // original transposed iteration - // A: input (BxSxNxH) (B.)S x NH S x NH - // B: weights (NxHx3xNxH) NH x (3.N.)H NH x H + // A: input (BxSxD) (B.)S x D S x D + // B: weights (Dx3xNxH) D x (3.N.)H D x H // C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset, @@ -215,9 +225,9 @@ Status QAttention::Compute(OpKernelContext* context) const { MlasGemm( sequence_length, // M = S head_size, // N = H - hidden_size, // K = NH + input_hidden_size, // K = D input_data + input_offset, // A - hidden_size, // lda = NH + input_hidden_size, // lda = D input_zero_point, // input zero point packed_weight, // B weight_zero_point, // weight zero point @@ -233,9 +243,9 @@ Status QAttention::Compute(OpKernelContext* context) const { MlasGemm( sequence_length, // M = S head_size, // N = H - hidden_size, // K = NH + input_hidden_size, // K = D input_data + input_offset, // A - hidden_size, // lda = NH + input_hidden_size, // lda = D input_zero_point, // input zero point weights_data + weights_offset, // B 3 * hidden_size, // ldb = 3NH diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index ce9147ad1b..4eee0ed90e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -41,16 +41,23 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* past = context->Input(4); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past)); - // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Output 0 - output : (batch_size, sequence_length, hidden_size) + // input shape (batch_size, sequence_length, input_hidden_size) const auto& shape = input->Shape(); int batch_size = static_cast(shape[0]); int sequence_length = static_cast(shape[1]); - int hidden_size = static_cast(shape[2]); - int head_size = hidden_size / num_heads_; + int input_hidden_size = static_cast(shape[2]); - Tensor* output = context->Output(0, shape); + // bias shape (3 * hidden_size) + const auto& bias_shape = bias->Shape(); + int hidden_size = static_cast(bias_shape[0]) / 3; + + int head_size = hidden_size / num_heads_; + + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); int past_sequence_length = 0; Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); @@ -61,7 +68,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Use GEMM for fully connection. int m = batch_size * sequence_length; int n = 3 * hidden_size; - int k = hidden_size; + int k = input_hidden_size; auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size); typedef typename ToCudaType::MappedType CudaT; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 5833e2fcee..79946d167a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -49,17 +49,6 @@ Status QAttention::CheckInputs(const Tensor* input, const Tensor* i_zp_tensor, const Tensor* w_zp_tensor, const Tensor* past_tensor) const { - // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Input 1 - weights : (hidden_size, 3 * hidden_size) - // Input 2 - bias : (3 * hidden_size) - // Input 3 - input_scale : scalar - // Input 4 - weight_scale : scalar - // Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length) - // Input 6 - input_zero_point : scalar - // Input 7 - weight_zero_point : scalar - // Output : (batch_size, sequence_length, hidden_size) - ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor)); ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor), @@ -89,12 +78,12 @@ Status QAttention::CheckInputs(const Tensor* input, template Status QAttention::ComputeInternal(OpKernelContext* context) const { // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Input 1 - weights : (hidden_size, 3 * hidden_size) + // Input 0 - input : (batch_size, sequence_length, input_hidden_size) + // Input 1 - weights : (input_hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) // Input 3 - input_scale : scalar // Input 4 - weight_scale : scalar - // Input 5 - mask_index : (batch_size) + // Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length) // Input 6 - input_zero_point : scalar // Input 7 - weight_zero_point : scalar // Input 8 - past : (2, batch_size, num_heads, past_sequence_length, head_size) @@ -124,10 +113,17 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { const auto& shape = input->Shape(); int batch_size = static_cast(shape[0]); int sequence_length = static_cast(shape[1]); - int hidden_size = static_cast(shape[2]); - int head_size = hidden_size / num_heads_; + int input_hidden_size = static_cast(shape[2]); - Tensor* output = context->Output(0, shape); + const auto& bias_shape = bias->Shape(); + const int hidden_size = static_cast(bias_shape.GetDims()[0]) / 3; + const int head_size = hidden_size / num_heads_; + + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); cublasHandle_t cublas = CublasHandle(); const size_t element_size = sizeof(T); @@ -135,7 +131,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { // Use GEMM for fully connection. int m = batch_size * sequence_length; int n = 3 * hidden_size; - int k = hidden_size; + int k = input_hidden_size; auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size); auto gemm_buffer_quantized = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index fdda10f467..de257885ed 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -275,6 +275,59 @@ void FusedMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputShape(ctx, 0, resultShape); } + +void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) { + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0); + if (ctx.getNumOutputs() > 1) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 1); + } + + // Shape inference + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + auto& input_shape = getInputShape(ctx, 0); + auto& input_dims = input_shape.dim(); + if (input_dims.size() != 3) { + fail_shape_inference("Inputs 0 shall be 3 dimensions"); + } + + auto& bias_shape = getInputShape(ctx, 2); + auto& bias_dims = bias_shape.dim(); + if (bias_dims.size() != 1 || bias_shape.dim(0).dim_value() % 3 != 0) { + fail_shape_inference("Invalid bias shape"); + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (auto& dim : input_dims) { + *output_shape.add_dim() = dim; + } + output_shape.mutable_dim(2)->set_dim_value(bias_shape.dim(0).dim_value() / 3); + updateOutputShape(ctx, 0, output_shape); + + if (ctx.getNumOutputs() > 1) { + if (hasInputShape(ctx, past_input_index)) { + auto& past_shape = getInputShape(ctx, past_input_index); + auto& past_dims = past_shape.dim(); + if (past_dims.size() != 5) { + fail_shape_inference("Inputs 4 shall be 5 dimensions"); + } + + if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) { + auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value(); + + ONNX_NAMESPACE::TensorShapeProto present_shape; + for (auto& dim : past_dims) { + *present_shape.add_dim() = dim; + } + present_shape.mutable_dim(3)->set_dim_value(all_sequence_length); + + updateOutputShape(ctx, 1, present_shape); + } + } + } + } +} + void RegisterBertSchemas() { static const char* Attention_ver1_doc = R"DOC( Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). @@ -296,8 +349,8 @@ and present state are optional. Present state could appear in output even when p "Whether every token can only attend to previous tokens. Default value is 0.", AttributeProto::INT, static_cast(0)) - .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T") - .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") + .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T") + .Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T") .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T") .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional) .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional) @@ -306,42 +359,8 @@ and present state are optional. Present state could appear in output even when p .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (ctx.getNumOutputs() > 1) { - propagateElemTypeFromInputToOutput(ctx, 0, 1); - } - - if (hasInputShape(ctx, 0)) { - propagateShapeFromInputToOutput(ctx, 0, 0); - - if (ctx.getNumOutputs() > 1) { - auto& input_shape = getInputShape(ctx, 0); - auto& input_dims = input_shape.dim(); - if (input_dims.size() != 3) { - fail_shape_inference("Inputs 0 shall be 3 dimensions"); - } - - if (hasInputShape(ctx, 4)) { - auto& past_shape = getInputShape(ctx, 4); - auto& past_dims = past_shape.dim(); - if (past_dims.size() != 5) { - fail_shape_inference("Inputs 4 shall be 5 dimensions"); - } - - if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) { - auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value(); - - ONNX_NAMESPACE::TensorShapeProto present_shape; - for (auto& dim : past_dims) { - *present_shape.add_dim() = dim; - } - present_shape.mutable_dim(3)->set_dim_value(all_sequence_length); - - updateOutputShape(ctx, 1, present_shape); - } - } - } - } + constexpr int past_input_index = 4; + AttentionTypeAndShapeInference(ctx, past_input_index); }); ONNX_CONTRIB_OPERATOR_SCHEMA(QAttention) @@ -356,12 +375,12 @@ and present state are optional. Present state could appear in output even when p .Input( 0, "input", - "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", + "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T1") .Input( 1, "weight", - "2D input tensor with shape (hidden_size, 3 * hidden_size)", + "2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size", "T2") .Input( 2, @@ -418,18 +437,8 @@ and present state are optional. Present state could appear in output even when p .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - // Type inference - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0); - - // Shape inference - // if the input shape doesn't exist, further shape inference is not possible - if (!hasNInputShapes(ctx, 1)) { - return; - } - - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); - - return; + constexpr int past_input_index = 8; + AttentionTypeAndShapeInference(ctx, past_input_index); }); static const char* Longformer_Attention_doc = R"DOC( @@ -456,7 +465,7 @@ Global attention flags have value 1 for the tokens attend globally and 0 otherwi .Input(4, "global_weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") .Input(5, "global_bias", "1D input tensor with shape (3 * hidden_size)", "T") .Input(6, "global", "Global attention flags with shape (batch_size, sequence_length)", "G") - .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T") + .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("G", {"tensor(int32)"}, "Constrain to integer types") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3.onnx index e58ffa62a1..6209c88e01 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_no_cast.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_no_cast.onnx index ed55d898f3..e100df9089 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_no_cast.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_no_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5.onnx index 0748062435..b2ed57c9c4 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx index dd51076224..7441a5d260 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx index 2d10dc96a8..421ad35601 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx index a133cc53cf..cd44646499 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx index 12cb4f977e..fd789ca51c 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py index 1118023ef3..63254f0fd0 100644 --- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py @@ -2,29 +2,32 @@ import onnx from onnx import helper from onnx import TensorProto from enum import Enum +from packaging import version +if version.parse(onnx.__version__) == version.parse('1.8.0'): + opset_version = 13 +elif version.parse(onnx.__version__) == version.parse('1.6.0'): + opset_version = 11 +else: + raise RuntimeError("Please pip install onnx==1.8.0 or 1.6.0 before running this script") def GenerateNodes(model_name, has_cast, suffix=''): nodes = [ # LayerNorm subgraph helper.make_node("Shape", ["input_ids" + suffix], ["shape1_out" + suffix], "shape1" + suffix), helper.make_node("Gather", ["shape1_out" + suffix, "indices_0"], ["gather0_out" + suffix], "gather0" + suffix), - helper.make_node("Unsqueeze", ["gather0_out" + suffix], ["unsqueeze0_out" + suffix], - "unsqueeze0" + suffix, - axes=[0]), + helper.make_node("Unsqueeze", ["gather0_out" + suffix, "axes_0"], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix) if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["gather0_out" + suffix], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix, axes=[0]), helper.make_node("Shape", ["input_ids" + suffix], ["shape2_out" + suffix], "shape2" + suffix), helper.make_node("Gather", ["shape2_out" + suffix, "indices_1"], ["gather1_out" + suffix], "gather1" + suffix), - helper.make_node("Unsqueeze", ["gather1_out" + suffix], ["unsqueeze1_out" + suffix], - "unsqueeze1" + suffix, - axes=[0]), + helper.make_node("Unsqueeze", ["gather1_out" + suffix, "axes_0"], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix) if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["gather1_out" + suffix], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix, axes=[0]), helper.make_node("Concat", ["unsqueeze0_out" + suffix, "unsqueeze1_out" + suffix], ["concat_out" + suffix], - "concat" + suffix, - axis=0), + "concat" + suffix, axis=0), helper.make_node("Cast", ["gather1_out" + suffix], ["cast_out" + suffix], "cast" + suffix, to=7), helper.make_node("Range", ["start_0", "cast_out" + suffix if has_cast else "gather1_out" + suffix, "delta_1"], ["range_out" + suffix], "range" + suffix), - helper.make_node("Unsqueeze", ["range_out" + suffix], ["unsqueeze2_out" + suffix], - "unsqueeze2" + suffix, - axes=[0]), + helper.make_node("Unsqueeze", ["range_out" + suffix, "axes_0"], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix) if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["range_out" + suffix], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix, axes=[0]), helper.make_node("Expand", ["unsqueeze2_out" + suffix, "concat_out" + suffix], ["expand_out" + suffix], "expand" + suffix), helper.make_node("Gather", ["pos_embed", "expand_out" + suffix], ["pos_gather_out" + suffix], @@ -43,10 +46,8 @@ def GenerateNodes(model_name, has_cast, suffix=''): axis=-1, epsion=0.000009999999747378752), helper.make_node("Cast", ["input_mask" + suffix], ["mask_cast_out" + suffix], "mask_cast" + suffix, to=6), - helper.make_node("ReduceSum", ["mask_cast_out" + suffix], ["mask_index_out" + suffix], - "mask_index" + suffix, - axes=[1], - keepdims=0), + helper.make_node("ReduceSum", ["mask_cast_out" + suffix, "axes_1"], ["mask_index_out" + suffix], "mask_index" + suffix, keepdims=0) if opset_version == 13 \ + else helper.make_node("ReduceSum", ["mask_cast_out" + suffix], ["mask_index_out" + suffix], "mask_index" + suffix, axes=[1], keepdims=0), helper.make_node("Attention", ["layernorm_out" + suffix, "qkv_weights", "qkv_bias", "mask_index_out" + suffix], ["att_out" + suffix], "att" + suffix, @@ -63,7 +64,7 @@ def GenerateNodes(model_name, has_cast, suffix=''): def GenerateInitializers(): - # hidden_size=4, num_heads=2, max_seq_length=3 + # hidden_size=4, num_heads=2 initializers = [ # initializers helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), @@ -75,12 +76,14 @@ def GenerateInitializers(): helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12), + helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]), + helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), ] return initializers @@ -154,7 +157,9 @@ def GenerateModel5(model_name): axis=-1, epsion=0.000009999999747378752), helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), - helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), + helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \ + else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), + helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], "att", domain="com.microsoft", @@ -164,11 +169,7 @@ def GenerateModel5(model_name): helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2") ] - qkv_weights = [ - 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, - 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, - 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0 - ] + qkv_weights = [1.0] * hidden_size * (3 * hidden_size) initializers = [ # initializers helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), @@ -185,6 +186,7 @@ def GenerateModel5(model_name): helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), ] graph = helper.make_graph( @@ -208,16 +210,19 @@ def GenerateModel6(model_name): nodes = [ # LayerNorm subgraph helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"), - helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Unsqueeze", ["gather0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"), helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"), - helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), + helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0), helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"), helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"), helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"), helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"], "range"), - helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), + helper.make_node("Unsqueeze", ["range_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"), helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"), helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"), @@ -229,7 +234,8 @@ def GenerateModel6(model_name): axis=-1, epsion=0.000009999999747378752), helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), - helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), + helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \ + else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], "att", domain="com.microsoft", @@ -251,15 +257,16 @@ def GenerateModel6(model_name): helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12), + helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12], [0.1] * 12), helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]), helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]), helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]), + helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]), + helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), ] graph = helper.make_graph( @@ -278,12 +285,9 @@ def GenerateModel6(model_name): model = helper.make_model(graph) onnx.save(model, model_name) + def GenerateInitializers2(hidden_size): - qkv_weights = [ - 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, - 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, - 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0 - ] + qkv_weights = [1.0] * hidden_size * (3 * hidden_size) initializers = [ # initializers helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), @@ -300,31 +304,32 @@ def GenerateInitializers2(hidden_size): helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]), + helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), ] return initializers + def GenerateNodes2(attention_heads): nodes = [ helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0), - helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"), helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"), helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"), - helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Unsqueeze", ["range0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"), helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"), helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather", axis=0), - helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["add1_out"], "add1"), helper.make_node("LayerNormalization", ["add1_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"], "layernorm", axis=-1, epsion=0.000009999999747378752), - helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), - - helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), + helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \ + else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], "att", domain="com.microsoft", @@ -336,6 +341,7 @@ def GenerateNodes2(attention_heads): return nodes + def GenerateModel7(model_name): batch_size = 2 hidden_size = 4 @@ -361,6 +367,7 @@ def GenerateModel7(model_name): model = helper.make_model(graph) onnx.save(model, model_name) + def GenerateModel8(model_name): batch_size = -1 hidden_size = 4 @@ -395,6 +402,7 @@ def GenerateModel8(model_name): model = helper.make_model(graph) onnx.save(model, model_name) + def GenerateModel9(model_name): batch_size = -1 hidden_size = 4 @@ -412,10 +420,13 @@ def GenerateModel9(model_name): helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"), helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"), helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"), - helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), + helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), + helper.make_node("Unsqueeze", ["gather2_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \ + else helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), helper.make_node("Concat", ["unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), - helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], "constant_of_shape", + helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], + "constant_of_shape", value=helper.make_tensor('mask_shape', TensorProto.FLOAT, [1], [1.0])), helper.make_node("Cast", ["constant_of_shape_out"], ["mask_cast_out"], "mask_cast", to=6), ] @@ -437,11 +448,21 @@ def GenerateModel9(model_name): model = helper.make_model(graph) onnx.save(model, model_name) -GenerateModel3('embed_layer_norm_format3.onnx', True) -GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False) -GenerateModel5('embed_layer_norm_format5.onnx') -GenerateModel6('embed_layer_norm_format6.onnx') -GenerateModel7('embed_layer_norm_format7.onnx') #distilbert -GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask -GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask -GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx') +if opset_version == 11: + GenerateModel3('embed_layer_norm_format3.onnx', True) + GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False) + GenerateModel5('embed_layer_norm_format5.onnx') + GenerateModel6('embed_layer_norm_format6.onnx') + GenerateModel7('embed_layer_norm_format7.onnx') #distilbert + GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask + GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask + GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx') +else: + GenerateModel3('embed_layer_norm_format3_opset13.onnx', True) + GenerateModel3('embed_layer_norm_format3_no_cast_opset13.onnx', False) + GenerateModel5('embed_layer_norm_format5_opset13.onnx') + GenerateModel6('embed_layer_norm_format6_opset13.onnx') + GenerateModel7('embed_layer_norm_format7_opset13.onnx') + GenerateModel8('embed_layer_norm_format8_opset13.onnx') + GenerateModel9('embed_layer_norm_format9_opset13.onnx') + GenerateMultipleEmbedModel('embed_layer_norm_multiple_opset13.onnx') \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_multiple.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_multiple.onnx index e090614283..d388c92ee2 100644 Binary files a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_multiple.onnx and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_multiple.onnx differ