fix grouped pointwise convolution (#7885)

This commit is contained in:
Tracy Sharpe 2021-06-01 08:54:04 -07:00 committed by GitHub
parent 3a72932c4a
commit 81ed6c55bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 20 deletions

View file

@ -510,14 +510,27 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
static_cast<size_t>(kernel_size));
} else {
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
gemm_params.ZeroPointA = X_zero_point_value;
if (packed_W_buffer_) {
gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_,
gemm_params.BIsPacked = true;
} else {
gemm_params.B = reordered_W + group_id * group_output_channels,
gemm_params.ldb = static_cast<size_t>(M);
}
gemm_params.ZeroPointB = &W_zero_point_value;
gemm_params.C = worker_gemm_output + group_id * group_output_channels;
gemm_params.ldc = static_cast<size_t>(M);
// Prepare the im2col transformation or use the input buffer directly for
// pointwise convolutions.
const uint8_t* worker_gemm_input;
const auto* group_input_data = input_data + group_id * group_input_channels;
if (col_buffer) {
auto* worker_col_buffer = static_cast<uint8_t*>(col_buffer.get()) + output_start * kernel_dim;
if (kernel_rank == 2) {
math::Im2col<uint8_t, StorageOrder::NHWC>()(
input_data + group_id * group_input_channels,
group_input_data,
group_input_channels,
C,
input_shape[0],
@ -537,7 +550,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
X_zero_point_value);
} else if (kernel_rank == 1) {
math::Im2col<uint8_t, StorageOrder::NHWC>()(
input_data + group_id * group_input_channels,
group_input_data,
group_input_channels,
C,
1,
@ -559,9 +572,11 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
// Use the im2col buffer prepared outside the thread, indexed by group.
worker_col_buffer += group_id * col_buffer_size;
}
worker_gemm_input = worker_col_buffer;
gemm_params.A = worker_col_buffer;
gemm_params.lda = static_cast<size_t>(kernel_dim);
} else {
worker_gemm_input = input_data + output_start * kernel_dim;
gemm_params.A = group_input_data + output_start * C;
gemm_params.lda = static_cast<size_t>(C);
}
MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape;
@ -570,20 +585,6 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
gemm_shape.K = static_cast<size_t>(kernel_dim);
gemm_shape.BIsSigned = is_W_signed;
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
gemm_params.A = worker_gemm_input;
gemm_params.lda = static_cast<size_t>(kernel_dim);
gemm_params.ZeroPointA = X_zero_point_value;
if (packed_W_buffer_) {
gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_,
gemm_params.BIsPacked = true;
} else {
gemm_params.B = reordered_W + group_id * group_output_channels,
gemm_params.ldb = static_cast<size_t>(M);
}
gemm_params.ZeroPointB = &W_zero_point_value;
gemm_params.C = worker_gemm_output + group_id * group_output_channels;
gemm_params.ldc = static_cast<size_t>(M);
MlasGemm(gemm_shape, gemm_params, nullptr);
}
}
@ -597,7 +598,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
output_scales.data(),
output_scales.size() > 1,
Y_zero_point_value,
0,0,
0,
0,
static_cast<size_t>(output_count),
static_cast<size_t>(M));
};

View file

@ -736,6 +736,26 @@ TEST(QLinearConvTest, Conv2D_U8S8_Groups_PerChannel) {
test.Run();
}
// Grouped pointwise 2D case: uint8 activations, int8 weights, three groups,
// and a weight tensor whose trailing (spatial) dimensions are all 1.
// NOTE(review): shapes presumably follow the ONNX layout (input N,C,H,W and
// weights M,C/group,kH,kW), i.e. 12 input channels and 15 output channels
// split over 3 groups -- confirm against QLinearConvOpTester.
TEST(QLinearConvTest, Conv2D_U8S8_Groups_Pointwise) {
  QLinearConvOpTester<uint8_t, int8_t> tester;

  // The helper calls are kept in their original order; the random-data
  // helpers may share generator state, so reordering could change the data.
  tester.GenerateRandomInput({1, 12, 17, 13}, 0.03f, 7);
  tester.GenerateRandomWeights({15, 4, 1, 1}, 0.10f, 0);
  tester.GenerateRandomBias();

  tester.SetGroups(3);
  tester.SetOutputScaleAndZeroPoint(0.26f, 88);

  tester.Run();
}
// Grouped pointwise 3D case: uint8 activations, int8 weights, two groups,
// a batch of two, and a weight tensor whose trailing (spatial) dimensions
// are all 1.
// NOTE(review): shapes presumably follow the ONNX layout (input N,C,D,H,W and
// weights M,C/group,kD,kH,kW), i.e. 4 input channels and 6 output channels
// split over 2 groups -- confirm against QLinearConvOpTester.
TEST(QLinearConvTest, Conv3D_U8S8_Groups_Pointwise) {
  QLinearConvOpTester<uint8_t, int8_t> tester;

  // The helper calls are kept in their original order; the random-data
  // helpers may share generator state, so reordering could change the data.
  tester.GenerateRandomInput({2, 4, 13, 17, 13}, 0.03f, 7);
  tester.GenerateRandomWeights({6, 2, 1, 1, 1}, 0.10f, 0);
  tester.GenerateRandomBias();

  tester.SetGroups(2);
  tester.SetOutputScaleAndZeroPoint(0.26f, 88);

  tester.Run();
}
TEST(QLinearConvTest, Conv1D_U8S8_Depthwise) {
for (int64_t channels : std::initializer_list<int64_t>{7, 8, 9, 16, 25, 64}) {
QLinearConvOpTester<uint8_t, int8_t> test;