From 81ed6c55bf3fb7e0e5f85ccf2e71092eeefcf3d7 Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Tue, 1 Jun 2021 08:54:04 -0700 Subject: [PATCH] fix grouped pointwise convolution (#7885) --- .../core/providers/cpu/nn/qlinearconv.cc | 42 ++++++++++--------- .../providers/cpu/nn/qlinearconv_op_test.cc | 20 +++++++++ 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc index 882138d50b..ff20ff0f8c 100644 --- a/onnxruntime/core/providers/cpu/nn/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/nn/qlinearconv.cc @@ -510,14 +510,27 @@ Status QLinearConv::Compute(OpKernelContext* context) const { static_cast(kernel_size)); } else { for (int64_t group_id = 0; group_id < group_count; ++group_id) { + MLAS_GEMM_U8X8_DATA_PARAMS gemm_params; + gemm_params.ZeroPointA = X_zero_point_value; + if (packed_W_buffer_) { + gemm_params.B = static_cast(packed_W_buffer_.get()) + group_id * packed_W_size_, + gemm_params.BIsPacked = true; + } else { + gemm_params.B = reordered_W + group_id * group_output_channels, + gemm_params.ldb = static_cast(M); + } + gemm_params.ZeroPointB = &W_zero_point_value; + gemm_params.C = worker_gemm_output + group_id * group_output_channels; + gemm_params.ldc = static_cast(M); + // Prepare the im2col transformation or use the input buffer directly for // pointwise convolutions. - const uint8_t* worker_gemm_input; + const auto* group_input_data = input_data + group_id * group_input_channels; if (col_buffer) { auto* worker_col_buffer = static_cast(col_buffer.get()) + output_start * kernel_dim; if (kernel_rank == 2) { math::Im2col()( - input_data + group_id * group_input_channels, + group_input_data, group_input_channels, C, input_shape[0], @@ -537,7 +550,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const { X_zero_point_value); } else if (kernel_rank == 1) { math::Im2col()( - input_data + group_id * group_input_channels, + group_input_data, group_input_channels, C, 1, @@ -559,9 +572,11 @@ Status QLinearConv::Compute(OpKernelContext* context) const { // Use the im2col buffer prepared outside the thread, indexed by group. worker_col_buffer += group_id * col_buffer_size; } - worker_gemm_input = worker_col_buffer; + gemm_params.A = worker_col_buffer; + gemm_params.lda = static_cast(kernel_dim); } else { - worker_gemm_input = input_data + output_start * kernel_dim; + gemm_params.A = group_input_data + output_start * C; + gemm_params.lda = static_cast(C); } MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape; @@ -570,20 +585,6 @@ Status QLinearConv::Compute(OpKernelContext* context) const { gemm_shape.K = static_cast(kernel_dim); gemm_shape.BIsSigned = is_W_signed; - MLAS_GEMM_U8X8_DATA_PARAMS gemm_params; - gemm_params.A = worker_gemm_input; - gemm_params.lda = static_cast(kernel_dim); - gemm_params.ZeroPointA = X_zero_point_value; - if (packed_W_buffer_) { - gemm_params.B = static_cast(packed_W_buffer_.get()) + group_id * packed_W_size_, - gemm_params.BIsPacked = true; - } else { - gemm_params.B = reordered_W + group_id * group_output_channels, - gemm_params.ldb = static_cast(M); - } - gemm_params.ZeroPointB = &W_zero_point_value; - gemm_params.C = worker_gemm_output + group_id * group_output_channels; - gemm_params.ldc = static_cast(M); MlasGemm(gemm_shape, gemm_params, nullptr); } } @@ -597,7 +598,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const { output_scales.data(), output_scales.size() > 1, Y_zero_point_value, - 0,0, + 0, + 0, static_cast(output_count), static_cast(M)); }; diff --git a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc index ea1869aec8..0916669e86 100644 --- a/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/qlinearconv_op_test.cc @@ -736,6 +736,26 @@ TEST(QLinearConvTest, Conv2D_U8S8_Groups_PerChannel) { test.Run(); } +TEST(QLinearConvTest, Conv2D_U8S8_Groups_Pointwise) { + QLinearConvOpTester test; + test.GenerateRandomInput({1, 12, 17, 13}, .03f, 7); + test.GenerateRandomWeights({15, 4, 1, 1}, .10f, 0); + test.GenerateRandomBias(); + test.SetGroups(3); + test.SetOutputScaleAndZeroPoint(.26f, 88); + test.Run(); +} + +TEST(QLinearConvTest, Conv3D_U8S8_Groups_Pointwise) { + QLinearConvOpTester test; + test.GenerateRandomInput({2, 4, 13, 17, 13}, .03f, 7); + test.GenerateRandomWeights({6, 2, 1, 1, 1}, .10f, 0); + test.GenerateRandomBias(); + test.SetGroups(2); + test.SetOutputScaleAndZeroPoint(.26f, 88); + test.Run(); +} + TEST(QLinearConvTest, Conv1D_U8S8_Depthwise) { for (int64_t channels : std::initializer_list{7, 8, 9, 16, 25, 64}) { QLinearConvOpTester test;