mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
fix grouped pointwise convolution (#7885)
This commit is contained in:
parent
3a72932c4a
commit
81ed6c55bf
2 changed files with 42 additions and 20 deletions
|
|
@ -510,14 +510,27 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
static_cast<size_t>(kernel_size));
|
||||
} else {
|
||||
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
|
||||
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
|
||||
gemm_params.ZeroPointA = X_zero_point_value;
|
||||
if (packed_W_buffer_) {
|
||||
gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_,
|
||||
gemm_params.BIsPacked = true;
|
||||
} else {
|
||||
gemm_params.B = reordered_W + group_id * group_output_channels,
|
||||
gemm_params.ldb = static_cast<size_t>(M);
|
||||
}
|
||||
gemm_params.ZeroPointB = &W_zero_point_value;
|
||||
gemm_params.C = worker_gemm_output + group_id * group_output_channels;
|
||||
gemm_params.ldc = static_cast<size_t>(M);
|
||||
|
||||
// Prepare the im2col transformation or use the input buffer directly for
|
||||
// pointwise convolutions.
|
||||
const uint8_t* worker_gemm_input;
|
||||
const auto* group_input_data = input_data + group_id * group_input_channels;
|
||||
if (col_buffer) {
|
||||
auto* worker_col_buffer = static_cast<uint8_t*>(col_buffer.get()) + output_start * kernel_dim;
|
||||
if (kernel_rank == 2) {
|
||||
math::Im2col<uint8_t, StorageOrder::NHWC>()(
|
||||
input_data + group_id * group_input_channels,
|
||||
group_input_data,
|
||||
group_input_channels,
|
||||
C,
|
||||
input_shape[0],
|
||||
|
|
@ -537,7 +550,7 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
X_zero_point_value);
|
||||
} else if (kernel_rank == 1) {
|
||||
math::Im2col<uint8_t, StorageOrder::NHWC>()(
|
||||
input_data + group_id * group_input_channels,
|
||||
group_input_data,
|
||||
group_input_channels,
|
||||
C,
|
||||
1,
|
||||
|
|
@ -559,9 +572,11 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
// Use the im2col buffer prepared outside the thread, indexed by group.
|
||||
worker_col_buffer += group_id * col_buffer_size;
|
||||
}
|
||||
worker_gemm_input = worker_col_buffer;
|
||||
gemm_params.A = worker_col_buffer;
|
||||
gemm_params.lda = static_cast<size_t>(kernel_dim);
|
||||
} else {
|
||||
worker_gemm_input = input_data + output_start * kernel_dim;
|
||||
gemm_params.A = group_input_data + output_start * C;
|
||||
gemm_params.lda = static_cast<size_t>(C);
|
||||
}
|
||||
|
||||
MLAS_GEMM_U8X8_SHAPE_PARAMS gemm_shape;
|
||||
|
|
@ -570,20 +585,6 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
gemm_shape.K = static_cast<size_t>(kernel_dim);
|
||||
gemm_shape.BIsSigned = is_W_signed;
|
||||
|
||||
MLAS_GEMM_U8X8_DATA_PARAMS gemm_params;
|
||||
gemm_params.A = worker_gemm_input;
|
||||
gemm_params.lda = static_cast<size_t>(kernel_dim);
|
||||
gemm_params.ZeroPointA = X_zero_point_value;
|
||||
if (packed_W_buffer_) {
|
||||
gemm_params.B = static_cast<const int8_t*>(packed_W_buffer_.get()) + group_id * packed_W_size_,
|
||||
gemm_params.BIsPacked = true;
|
||||
} else {
|
||||
gemm_params.B = reordered_W + group_id * group_output_channels,
|
||||
gemm_params.ldb = static_cast<size_t>(M);
|
||||
}
|
||||
gemm_params.ZeroPointB = &W_zero_point_value;
|
||||
gemm_params.C = worker_gemm_output + group_id * group_output_channels;
|
||||
gemm_params.ldc = static_cast<size_t>(M);
|
||||
MlasGemm(gemm_shape, gemm_params, nullptr);
|
||||
}
|
||||
}
|
||||
|
|
@ -597,7 +598,8 @@ Status QLinearConv::Compute(OpKernelContext* context) const {
|
|||
output_scales.data(),
|
||||
output_scales.size() > 1,
|
||||
Y_zero_point_value,
|
||||
0,0,
|
||||
0,
|
||||
0,
|
||||
static_cast<size_t>(output_count),
|
||||
static_cast<size_t>(M));
|
||||
};
|
||||
|
|
|
|||
|
|
@ -736,6 +736,26 @@ TEST(QLinearConvTest, Conv2D_U8S8_Groups_PerChannel) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(QLinearConvTest, Conv2D_U8S8_Groups_Pointwise) {
|
||||
QLinearConvOpTester<uint8_t, int8_t> test;
|
||||
test.GenerateRandomInput({1, 12, 17, 13}, .03f, 7);
|
||||
test.GenerateRandomWeights({15, 4, 1, 1}, .10f, 0);
|
||||
test.GenerateRandomBias();
|
||||
test.SetGroups(3);
|
||||
test.SetOutputScaleAndZeroPoint(.26f, 88);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(QLinearConvTest, Conv3D_U8S8_Groups_Pointwise) {
|
||||
QLinearConvOpTester<uint8_t, int8_t> test;
|
||||
test.GenerateRandomInput({2, 4, 13, 17, 13}, .03f, 7);
|
||||
test.GenerateRandomWeights({6, 2, 1, 1, 1}, .10f, 0);
|
||||
test.GenerateRandomBias();
|
||||
test.SetGroups(2);
|
||||
test.SetOutputScaleAndZeroPoint(.26f, 88);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(QLinearConvTest, Conv1D_U8S8_Depthwise) {
|
||||
for (int64_t channels : std::initializer_list<int64_t>{7, 8, 9, 16, 25, 64}) {
|
||||
QLinearConvOpTester<uint8_t, int8_t> test;
|
||||
|
|
|
|||
Loading…
Reference in a new issue