diff --git a/onnxruntime/test/contrib_ops/qlinear_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_pool_test.cc index 94916ebec3..13408ace01 100644 --- a/onnxruntime/test/contrib_ops/qlinear_pool_test.cc +++ b/onnxruntime/test/contrib_ops/qlinear_pool_test.cc @@ -206,8 +206,8 @@ static std::vector transpose_to_nhwc(const std::vector& nchw_d auto channels = nchw_dims[1]; int64_t image_size = std::accumulate(nchw_dims.begin() + 2, nchw_dims.end(), 1LL, std::multiplies()); for (int64_t b = 0; b < batch_count; b++) { - const uint8_t* nchw_image = nchw_data.data() + (b * image_size); - uint8_t* nhwc_image = nhwc_data.data() + (b * image_size); + const uint8_t* nchw_image = nchw_data.data() + (b * channels * image_size); + uint8_t* nhwc_image = nhwc_data.data() + (b * channels * image_size); for (int64_t img_index = 0; img_index < image_size; ++img_index) { for (int64_t c = 0; c < channels; c++) { *nhwc_image++ = nchw_image[c * image_size + img_index]; @@ -331,6 +331,16 @@ TEST(QLinearPoolTest, AveragePool2D_IncludePadPixel) { 1); // count_include_pad } +TEST(QLinearPoolTest, AveragePool2D_MultiChannel) { + RunQLinearAveragePoolNchwU8( + {1, 3, 5, 7}, // x shape + {1, 3, 6, 4}, // expected y shape + {3, 4}, // kernel shape + {1, 2}, // strides + {1, 3, 2, 1}, // pads + 1); // count_include_pad +} + TEST(QLinearPoolTest, AveragePool3D_ExcludePadPixel) { RunQLinearAveragePoolNchwU8( {1, 1, 5, 7, 9}, // x shape @@ -394,6 +404,16 @@ TEST(QLinearPoolTest, AveragePool2D_IncludePadPixel_nhwc) { 1); // count_include_pad } +TEST(QLinearPoolTest, AveragePool2D_MultiChannel_nhwc) { + RunQLinearAveragePoolNhwcU8( + {1, 3, 5, 7}, // x shape + {1, 3, 6, 4}, // expected y shape + {3, 4}, // kernel shape + {1, 2}, // strides + {1, 3, 2, 1}, // pads + 1); // count_include_pad +} + TEST(QLinearPoolTest, AveragePool3D_ExcludePadPixel_nhwc) { RunQLinearAveragePoolNhwcU8( {1, 1, 5, 7, 9}, // x shape