From 2d6e10ba0054bfbbbc89f1e6d945b4ca3d59631f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 27 Feb 2021 09:50:16 -0800 Subject: [PATCH] Update Attention and QAttention to support pruned model (#6819) * update Attention operator spec to support pruned model * update Attention and QAttention cpu & cuda kernel * Fix invalid embed layer norm fusion test models. --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 94 +++++++------ .../cpu/quantization/attention_quant.cc | 44 +++--- .../contrib_ops/cuda/bert/attention.cc | 21 ++- .../quantization/attention_quantization.cc | 32 ++--- .../core/graph/contrib_ops/contrib_defs.cc | 115 +++++++++------- .../fusion/embed_layer_norm_format3.onnx | Bin 2239 -> 2435 bytes .../embed_layer_norm_format3_no_cast.onnx | Bin 2194 -> 2390 bytes .../fusion/embed_layer_norm_format5.onnx | Bin 1756 -> 1709 bytes .../fusion/embed_layer_norm_format6.onnx | Bin 2461 -> 2657 bytes .../fusion/embed_layer_norm_format7.onnx | Bin 1909 -> 1879 bytes .../fusion/embed_layer_norm_format8.onnx | Bin 1919 -> 1889 bytes .../fusion/embed_layer_norm_format9.onnx | Bin 2308 -> 2278 bytes .../transform/fusion/embed_layer_norm_gen.py | 129 ++++++++++-------- .../fusion/embed_layer_norm_multiple.onnx | Bin 4077 -> 4273 bytes 14 files changed, 247 insertions(+), 188 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index be0aae79eb..65352fe2a1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -54,14 +54,17 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const Tensor*& mask_index, const Tensor* past) const { // Input shapes: - // input : (batch_size, sequence_length, hidden_size) - // weights : (hidden_size, 3 * hidden_size) + // input : (batch_size, sequence_length, input_hidden_size) + // weights : (input_hidden_size, 3 * hidden_size) // bias : (3 * hidden_size) // mask_index : nullptr, (batch_size), (2 * batch_size), // or (batch_size, 1), (1, 1) // or (batch_size, past_sequence_length + sequence_length) // or (batch_size, sequence_length, past_sequence_length + sequence_length) // past : (2, batch_size, num_heads, past_sequence_length, head_size) + // + // Where hidden_size = num_heads * head_size. + // When a model is pruned (like some attention heads are removed), hidden_size < input_hidden_size. const auto& dims = input_shape.GetDims(); if (dims.size() != 3) { @@ -70,11 +73,6 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } int batch_size = static_cast(dims[0]); int sequence_length = static_cast(dims[1]); - int hidden_size = static_cast(dims[2]); - if (hidden_size % num_heads_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 0 dimension 2 should be divisiable by value of the num_heads attribute."); - } const auto& weights_dims = weights_shape.GetDims(); if (weights_dims.size() != 2) { @@ -85,8 +83,15 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 1 dimension 0 should have same length as dimension 2 of input 0"); } - if (weights_dims[1] != 3 * weights_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' dimension 1 should be 3 times of dimension 0"); + + int hidden_size = static_cast(weights_dims[1]) / 3; + if (3 * hidden_size != static_cast(weights_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 dimension 1 should be 3 times of hidden dimension"); + } + + if (hidden_size % num_heads_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads."); } const auto& bias_dims = bias_shape.GetDims(); @@ -125,7 +130,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const auto& mask_dims = mask_index->Shape().GetDims(); if (mask_dims.size() == 1) { if (static_cast(mask_dims[0]) != batch_size && static_cast(mask_dims[0]) != 2 * batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' dimension 0 shall have length of batch_size or 2 * batch_size"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size"); } } else if (mask_dims.size() == 2) { if (static_cast(mask_dims[0]) != batch_size || static_cast(mask_dims[1]) != past_sequence_length + sequence_length) { @@ -134,12 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // Mask will have same value after propogation, which has same effect as no mask. mask_index = nullptr; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)"); } } } else if (mask_dims.size() == 3) { if (static_cast(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast(mask_dims[2]) != past_sequence_length + sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' of 3d shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)"); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ", @@ -193,8 +198,9 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, bool& is_pack return Status::OK(); } - const size_t hidden_size = static_cast(weights_dims[0]); + const size_t input_hidden_size = static_cast(weights_dims[0]); const size_t hidden_size_x3 = static_cast(weights_dims[1]); + const size_t hidden_size = hidden_size_x3 / 3; const size_t head_size = hidden_size / num_heads_; // Bail out if the weights shape has an expected shape. @@ -204,7 +210,7 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, bool& is_pack const auto* weights_data = weights.Data(); - packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size); + packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size); if (packed_weights_size_ == 0) { return Status::OK(); } @@ -215,7 +221,7 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, bool& is_pack packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(CblasNoTrans, head_size, hidden_size, weights_data, hidden_size_x3, packed_weights_data); + MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, hidden_size_x3, packed_weights_data); packed_weights_data += packed_weights_size_; weights_data += head_size; } @@ -232,8 +238,9 @@ Status Attention::Compute(OpKernelContext* context) const { const Tensor* mask_index = context->Input(3); const Tensor* past = context->Input(4); + const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape()); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), - weights ? weights->Shape() : weight_shape_, + weights_shape, bias->Shape(), mask_index, past)); @@ -241,10 +248,17 @@ Status Attention::Compute(OpKernelContext* context) const { const auto& shape = input->Shape().GetDims(); const int batch_size = static_cast(shape[0]); const int sequence_length = static_cast(shape[1]); - const int hidden_size = static_cast(shape[2]); + const int input_hidden_size = static_cast(shape[2]); + + const auto& weights_dims = weights_shape.GetDims(); + const int hidden_size = static_cast(weights_dims[1]) / 3; const int head_size = hidden_size / num_heads_; - Tensor* output = context->Output(0, shape); + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); constexpr size_t element_size = sizeof(T); @@ -253,7 +267,8 @@ Status Attention::Compute(OpKernelContext* context) const { auto* tp = context->GetOperatorThreadPool(); // Compute Q, K, V - // gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH) + // gemm_data(BS, 3NH) = input(BS, D) x weights(D, 3NH) + bias(3NH) + // D (input_hidden_size) is hidden dimension of input, where D could be larger than hidden_size (NH) when model is pruned. auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator)); @@ -269,14 +284,14 @@ Status Attention::Compute(OpKernelContext* context) const { const auto* bias_data = bias->template Data(); const double cost = - static_cast(sequence_length) * static_cast(head_size) * static_cast(hidden_size); + static_cast(sequence_length) * static_cast(head_size) * static_cast(input_hidden_size); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast((i / 3) / num_heads_); const int head_index = static_cast((i / 3) % num_heads_); const int qkv_index = static_cast(i % 3); - int input_offset = batch_index * sequence_length * hidden_size; + int input_offset = batch_index * sequence_length * input_hidden_size; int weights_offset = qkv_index * hidden_size + head_index * head_size; T* qkv_dest = QKV[qkv_index]; int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); @@ -290,8 +305,8 @@ Status Attention::Compute(OpKernelContext* context) const { } // original transposed iteration - // A: input (BxSxNxH) (B.)S x NH S x NH - // B: weights (NxHx3xNxH) NH x (3.N.)H NH x H + // A: input (BxSxD) (B.)S x D S x D + // B: weights (Dx3xNxH) D x (3.N.)H D x H // C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H if (packed_weights_) { const auto* packed_weight = @@ -300,30 +315,31 @@ Status Attention::Compute(OpKernelContext* context) const { CblasNoTrans, // TransA = no sequence_length, // M = S head_size, // N = H - hidden_size, // K = NH + input_hidden_size, // K = D 1.0f, // alpha input_data + input_offset, // A - hidden_size, // lda = NH + input_hidden_size, // lda = D packed_weight, // B 1.0f, // beta qkv_dest + qkv_offset, // C head_size, // ldc nullptr); // use single-thread } else { - math::GemmEx(CblasNoTrans, // TransA = no - CblasNoTrans, // TransB = no - sequence_length, // M = S - head_size, // N = H - hidden_size, // K = NH - 1.0f, // alpha - input_data + input_offset, // A - hidden_size, // lda = NH - weights_data + weights_offset, // B - 3 * hidden_size, // ldb = 3NH - 1.0f, // beta - qkv_dest + qkv_offset, // C - head_size, // ldc - nullptr // use single-thread + math::GemmEx( + CblasNoTrans, // TransA = no + CblasNoTrans, // TransB = no + sequence_length, // M = S + head_size, // N = H + input_hidden_size, // K = D + 1.0f, // alpha + input_data + input_offset, // A + input_hidden_size, // lda = D + weights_data + weights_offset, // B + 3 * hidden_size, // ldb = 3NH + 1.0f, // beta + qkv_dest + qkv_offset, // C + head_size, // ldc + nullptr // use single-thread ); } } diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 2c2ef835aa..e871a69c86 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -66,8 +66,9 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac return Status::OK(); } - const size_t hidden_size = static_cast(weights_dims[0]); + const size_t input_hidden_size = static_cast(weights_dims[0]); const size_t hidden_size_x3 = static_cast(weights_dims[1]); + const size_t hidden_size = hidden_size_x3 / 3; const size_t head_size = hidden_size / num_heads_; // Bail out if the weights shape has an expected shape. @@ -78,7 +79,7 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac const auto* weights_data = static_cast(weights.DataRaw()); weights_is_signed_ = weights.IsDataType(); - packed_weights_size_ = MlasGemmPackBSize(head_size, hidden_size, weights_is_signed_); + packed_weights_size_ = MlasGemmPackBSize(head_size, input_hidden_size, weights_is_signed_); if (packed_weights_size_ == 0) { return Status::OK(); } @@ -89,7 +90,7 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac packed_weights_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc)); for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(head_size, hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data); + MlasGemmPackB(head_size, input_hidden_size, weights_data, hidden_size_x3, weights_is_signed_, packed_weights_data); packed_weights_data += packed_weights_size_; weights_data += head_size; } @@ -102,8 +103,8 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, bool& is_pac template Status QAttention::Compute(OpKernelContext* context) const { // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Input 1 - weights : (hidden_size, 3 * hidden_size) + // Input 0 - input : (batch_size, sequence_length, input_hidden_size) + // Input 1 - weights : (input_hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) // Input 3 - input_scale : scalar // Input 4 - weight_scale : scalar @@ -124,8 +125,9 @@ Status QAttention::Compute(OpKernelContext* context) const { const Tensor* w_zp_tensor = context->Input(7); const Tensor* past_tensor = context->Input(8); + const TensorShape& weights_shape = (packed_weights_ ? weight_shape_ : weights->Shape()); ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), - packed_weights_ ? weight_shape_ : weights->Shape(), + weights_shape, bias->Shape(), mask_index, past_tensor)); @@ -157,10 +159,17 @@ Status QAttention::Compute(OpKernelContext* context) const { const auto& shape = input->Shape(); const int batch_size = static_cast(shape[0]); const int sequence_length = static_cast(shape[1]); - const int hidden_size = static_cast(shape[2]); + const int input_hidden_size = static_cast(shape[2]); + + const auto hidden_size_x3 = weights_shape.GetDims()[1]; + const int hidden_size = static_cast(hidden_size_x3) / 3; const int head_size = hidden_size / num_heads_; - Tensor* output = context->Output(0, shape); + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -168,7 +177,8 @@ Status QAttention::Compute(OpKernelContext* context) const { constexpr size_t element_size = sizeof(T); auto* tp = context->GetOperatorThreadPool(); - // STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, NH) x weights(NH, 3NH)) + bias(3NH) + // STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, D) x weights(D, 3NH)) + bias(3NH) + // D is hidden dimension of input, where input_hidden_size (D) could be larger than hidden_size (NH) when model is pruned. auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator)); @@ -186,21 +196,21 @@ Status QAttention::Compute(OpKernelContext* context) const { const bool weights_is_signed = packed_weights_ ? weights_is_signed_ : weights->IsDataType(); const double cost = - static_cast(sequence_length) * static_cast(head_size) * static_cast(hidden_size); + static_cast(sequence_length) * static_cast(head_size) * static_cast(input_hidden_size); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast((i / 3) / num_heads_); const int head_index = static_cast((i / 3) % num_heads_); const int qkv_index = static_cast(i % 3); - int input_offset = batch_index * sequence_length * hidden_size; + int input_offset = batch_index * sequence_length * input_hidden_size; int weights_offset = qkv_index * hidden_size + head_index * head_size; float* qkv_dest = QKV[qkv_index]; int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); // original transposed iteration - // A: input (BxSxNxH) (B.)S x NH S x NH - // B: weights (NxHx3xNxH) NH x (3.N.)H NH x H + // A: input (BxSxD) (B.)S x D S x D + // B: weights (Dx3xNxH) D x (3.N.)H D x H // C: QKV[qkv_index] (3xBxNxSxH) (3.B.N.)S x H S x H MLAS_QGEMM_SCALE_BIAS_OUTPUT_PROCESSOR scale_bias_processor(qkv_dest + qkv_offset, @@ -215,9 +225,9 @@ Status QAttention::Compute(OpKernelContext* context) const { MlasGemm( sequence_length, // M = S head_size, // N = H - hidden_size, // K = NH + input_hidden_size, // K = D input_data + input_offset, // A - hidden_size, // lda = NH + input_hidden_size, // lda = D input_zero_point, // input zero point packed_weight, // B weight_zero_point, // weight zero point @@ -233,9 +243,9 @@ Status QAttention::Compute(OpKernelContext* context) const { MlasGemm( sequence_length, // M = S head_size, // N = H - hidden_size, // K = NH + input_hidden_size, // K = D input_data + input_offset, // A - hidden_size, // lda = NH + input_hidden_size, // lda = D input_zero_point, // input zero point weights_data + weights_offset, // B 3 * hidden_size, // ldb = 3NH diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index ce9147ad1b..4eee0ed90e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -41,16 +41,23 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* past = context->Input(4); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past)); - // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Output 0 - output : (batch_size, sequence_length, hidden_size) + // input shape (batch_size, sequence_length, input_hidden_size) const auto& shape = input->Shape(); int batch_size = static_cast(shape[0]); int sequence_length = static_cast(shape[1]); - int hidden_size = static_cast(shape[2]); - int head_size = hidden_size / num_heads_; + int input_hidden_size = static_cast(shape[2]); - Tensor* output = context->Output(0, shape); + // bias shape (3 * hidden_size) + const auto& bias_shape = bias->Shape(); + int hidden_size = static_cast(bias_shape[0]) / 3; + + int head_size = hidden_size / num_heads_; + + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); int past_sequence_length = 0; Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); @@ -61,7 +68,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Use GEMM for fully connection. int m = batch_size * sequence_length; int n = 3 * hidden_size; - int k = hidden_size; + int k = input_hidden_size; auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size); typedef typename ToCudaType::MappedType CudaT; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 5833e2fcee..79946d167a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -49,17 +49,6 @@ Status QAttention::CheckInputs(const Tensor* input, const Tensor* i_zp_tensor, const Tensor* w_zp_tensor, const Tensor* past_tensor) const { - // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Input 1 - weights : (hidden_size, 3 * hidden_size) - // Input 2 - bias : (3 * hidden_size) - // Input 3 - input_scale : scalar - // Input 4 - weight_scale : scalar - // Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length) - // Input 6 - input_zero_point : scalar - // Input 7 - weight_zero_point : scalar - // Output : (batch_size, sequence_length, hidden_size) - ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past_tensor)); ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor), @@ -89,12 +78,12 @@ Status QAttention::CheckInputs(const Tensor* input, template Status QAttention::ComputeInternal(OpKernelContext* context) const { // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Input 1 - weights : (hidden_size, 3 * hidden_size) + // Input 0 - input : (batch_size, sequence_length, input_hidden_size) + // Input 1 - weights : (input_hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) // Input 3 - input_scale : scalar // Input 4 - weight_scale : scalar - // Input 5 - mask_index : (batch_size) + // Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length) // Input 6 - input_zero_point : scalar // Input 7 - weight_zero_point : scalar // Input 8 - past : (2, batch_size, num_heads, past_sequence_length, head_size) @@ -124,10 +113,17 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { const auto& shape = input->Shape(); int batch_size = static_cast(shape[0]); int sequence_length = static_cast(shape[1]); - int hidden_size = static_cast(shape[2]); - int head_size = hidden_size / num_heads_; + int input_hidden_size = static_cast(shape[2]); - Tensor* output = context->Output(0, shape); + const auto& bias_shape = bias->Shape(); + const int hidden_size = static_cast(bias_shape.GetDims()[0]) / 3; + const int head_size = hidden_size / num_heads_; + + std::vector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); cublasHandle_t cublas = CublasHandle(); const size_t element_size = sizeof(T); @@ -135,7 +131,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { // Use GEMM for fully connection. int m = batch_size * sequence_length; int n = 3 * hidden_size; - int k = hidden_size; + int k = input_hidden_size; auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size); auto gemm_buffer_quantized = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index fdda10f467..de257885ed 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -275,6 +275,59 @@ void FusedMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputShape(ctx, 0, resultShape); } + +void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) { + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0); + if (ctx.getNumOutputs() > 1) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 1); + } + + // Shape inference + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + auto& input_shape = getInputShape(ctx, 0); + auto& input_dims = input_shape.dim(); + if (input_dims.size() != 3) { + fail_shape_inference("Inputs 0 shall be 3 dimensions"); + } + + auto& bias_shape = getInputShape(ctx, 2); + auto& bias_dims = bias_shape.dim(); + if (bias_dims.size() != 1 || bias_shape.dim(0).dim_value() % 3 != 0) { + fail_shape_inference("Invalid bias shape"); + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (auto& dim : input_dims) { + *output_shape.add_dim() = dim; + } + output_shape.mutable_dim(2)->set_dim_value(bias_shape.dim(0).dim_value() / 3); + updateOutputShape(ctx, 0, output_shape); + + if (ctx.getNumOutputs() > 1) { + if (hasInputShape(ctx, past_input_index)) { + auto& past_shape = getInputShape(ctx, past_input_index); + auto& past_dims = past_shape.dim(); + if (past_dims.size() != 5) { + fail_shape_inference("Inputs 4 shall be 5 dimensions"); + } + + if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) { + auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value(); + + ONNX_NAMESPACE::TensorShapeProto present_shape; + for (auto& dim : past_dims) { + *present_shape.add_dim() = dim; + } + present_shape.mutable_dim(3)->set_dim_value(all_sequence_length); + + updateOutputShape(ctx, 1, present_shape); + } + } + } + } +} + void RegisterBertSchemas() { static const char* Attention_ver1_doc = R"DOC( Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). @@ -296,8 +349,8 @@ and present state are optional. Present state could appear in output even when p "Whether every token can only attend to previous tokens. Default value is 0.", AttributeProto::INT, static_cast(0)) - .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T") - .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") + .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T") + .Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), where hidden_size = num_heads * head_size", "T") .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T") .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional) .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional) @@ -306,42 +359,8 @@ and present state are optional. Present state could appear in output even when p .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (ctx.getNumOutputs() > 1) { - propagateElemTypeFromInputToOutput(ctx, 0, 1); - } - - if (hasInputShape(ctx, 0)) { - propagateShapeFromInputToOutput(ctx, 0, 0); - - if (ctx.getNumOutputs() > 1) { - auto& input_shape = getInputShape(ctx, 0); - auto& input_dims = input_shape.dim(); - if (input_dims.size() != 3) { - fail_shape_inference("Inputs 0 shall be 3 dimensions"); - } - - if (hasInputShape(ctx, 4)) { - auto& past_shape = getInputShape(ctx, 4); - auto& past_dims = past_shape.dim(); - if (past_dims.size() != 5) { - fail_shape_inference("Inputs 4 shall be 5 dimensions"); - } - - if (past_dims[3].has_dim_value() && input_dims[1].has_dim_value()) { - auto all_sequence_length = past_shape.dim(3).dim_value() + input_shape.dim(1).dim_value(); - - ONNX_NAMESPACE::TensorShapeProto present_shape; - for (auto& dim : past_dims) { - *present_shape.add_dim() = dim; - } - present_shape.mutable_dim(3)->set_dim_value(all_sequence_length); - - updateOutputShape(ctx, 1, present_shape); - } - } - } - } + constexpr int past_input_index = 4; + AttentionTypeAndShapeInference(ctx, past_input_index); }); ONNX_CONTRIB_OPERATOR_SCHEMA(QAttention) @@ -356,12 +375,12 @@ and present state are optional. Present state could appear in output even when p .Input( 0, "input", - "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", + "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T1") .Input( 1, "weight", - "2D input tensor with shape (hidden_size, 3 * hidden_size)", + "2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size", "T2") .Input( 2, @@ -418,18 +437,8 @@ and present state are optional. Present state could appear in output even when p .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - // Type inference - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0); - - // Shape inference - // if the input shape doesn't exist, further shape inference is not possible - if (!hasNInputShapes(ctx, 1)) { - return; - } - - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); - - return; + constexpr int past_input_index = 8; + AttentionTypeAndShapeInference(ctx, past_input_index); }); static const char* Longformer_Attention_doc = R"DOC( @@ -456,7 +465,7 @@ Global attention flags have value 1 for the tokens attend globally and 0 otherwi .Input(4, "global_weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") .Input(5, "global_bias", "1D input tensor with shape (3 * hidden_size)", "T") .Input(6, "global", "Global attention flags with shape (batch_size, sequence_length)", "G") - .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T") + .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("G", {"tensor(int32)"}, "Constrain to integer types") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3.onnx index e58ffa62a142f9652f9e6f396bb174cac9d8f9de..6209c88e010e687aa28c0250be139581d02c8304 100644 GIT binary patch delta 316 zcmdll*euM-!DjVKXd|mBd;K*=4i*j`0Y;?*jAzfBv86Yzr>2EY+=ba?@#U$R=@})( zS`J`u89*%Xv70q>rX7$C!a!fb_)KT)CeL7>s-(rw!6?9P#mL~qmRONm9B%*|?!Dh8ha3iZJJ7ehNST;`{?!xS{`0~`u^o)|?$!j^JCi`$SPyWa;2>>WC B4zB0*p!r7|)(LV@q#bPfZJ*xC^t(;>%Mr(=$qn zwH(0SGJsg%V>fH&OgkVOgn_<<@tMxpP0nVYs-(rw!6?9P#mL~qmRONm9B%*7SQBpj)fkSGtJV!G}lmZuLW?n(* P_sG4cQa{{9il diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5.onnx index 074806243552d75435f6d6cbd772395c68f258ae..b2ed57c9c495a6945bbbf0d0b269ef307df6937a 100644 GIT binary patch delta 244 zcmcb^yOx)QgUxEmMwV8V`fH3FEF3%nj7kR>85kPu>5YA;Y2oB4tfvIE_&FE_*sT~D Ro!Am9Qj6maC&#hX0RTa$F!BHZ delta 53 zcmZ3>dxw{WgU#ylMwV6ntX=i;pFq26B&5` DZ6**I diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7.onnx index 2d10dc96a8470651f84194d52067647337e4e086..421ad356019d231421ce2c580bffce80d0abf6d6 100644 GIT binary patch delta 263 zcmey$cb$)ggU#yPMiyJv`fH3FEF3%nj7kR>85kPu>5YA;Y2joawo^)4{2Yt|>{g5n YPHc%4sm1XIFabuWfZ=2@_Bcjv01tyRr2qf` delta 95 zcmcc4_mz)@gWc-QMiyIE?rBUMEF3%nj7kkmlN(tSCP%Xfkbwm!^Rk_s{DUoykp}>G Ce;82! diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8.onnx index a133cc53cf0547a18a3af8effcb0d2065257b4ab..cd44646499ae53c3109549d53406ed89d8bb084a 100644 GIT binary patch delta 263 zcmey*_mGc;gU#y3MixQV`fH3FEF3%nj7kR>85kPu>5YA;Y2joAwo^)4{2Yt|>{g5n YPHc%4sm1XIFabuWfZ^mMjikH Cy%tsg diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9.onnx index 12cb4f977e0832ef4e813c3bca261b4b2f48f74e..fd789ca51c6186d86310026a9d75d1912fbf384e 100644 GIT binary patch delta 263 zcmZn>dM3!i!De-LBg-+i`fH3FEF3%nj7kR>85kPu>5YA;Y2oCn?5C8p_&FE_*sT~D YoY)d8Qj6maU;>O#0mI4jIZ_z80b?CB{Qv*} delta 55 zcmaDR*doNj!EW_?Bg-*1?rBUMEF3%nj7kkmli#sBO#a7ifP@VvPh&qhxsfAjggTX00C*JOaK4? delta 52 zcmdm}_*R~kgWc+h{6^Myyo{lf`T0FXxC^t(;>%Mr(=$qnwG=p51Q?YBCV%Fex_K2p I8zUnR0IgULJ^%m!