Add and correct multichannel test for qlinear pool test. (#7864)

This commit is contained in:
Zhang Lei 2021-06-04 12:03:41 -07:00 committed by GitHub
parent 5a7f65b831
commit 0975e7c9a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -206,8 +206,8 @@ static std::vector<uint8_t> transpose_to_nhwc(const std::vector<uint8_t>& nchw_d
auto channels = nchw_dims[1];
int64_t image_size = std::accumulate(nchw_dims.begin() + 2, nchw_dims.end(), 1LL, std::multiplies<int64_t>());
for (int64_t b = 0; b < batch_count; b++) {
const uint8_t* nchw_image = nchw_data.data() + (b * image_size);
uint8_t* nhwc_image = nhwc_data.data() + (b * image_size);
const uint8_t* nchw_image = nchw_data.data() + (b * channels * image_size);
uint8_t* nhwc_image = nhwc_data.data() + (b * channels * image_size);
for (int64_t img_index = 0; img_index < image_size; ++img_index) {
for (int64_t c = 0; c < channels; c++) {
*nhwc_image++ = nchw_image[c * image_size + img_index];
@ -331,6 +331,16 @@ TEST(QLinearPoolTest, AveragePool2D_IncludePadPixel) {
1); // count_include_pad
}
TEST(QLinearPoolTest, AveragePool2D_MultiChannel) {
RunQLinearAveragePoolNchwU8(
{1, 3, 5, 7}, // x shape
{1, 3, 6, 4}, // expected y shape
{3, 4}, // kernel shape
{1, 2}, // strides
{1, 3, 2, 1}, // pads
1); // count_include_pad
}
TEST(QLinearPoolTest, AveragePool3D_ExcludePadPixel) {
RunQLinearAveragePoolNchwU8(
{1, 1, 5, 7, 9}, // x shape
@ -394,6 +404,16 @@ TEST(QLinearPoolTest, AveragePool2D_IncludePadPixel_nhwc) {
1); // count_include_pad
}
TEST(QLinearPoolTest, AveragePool2D_MultiChannel_nhwc) {
RunQLinearAveragePoolNhwcU8(
{1, 3, 5, 7}, // x shape
{1, 3, 6, 4}, // expected y shape
{3, 4}, // kernel shape
{1, 2}, // strides
{1, 3, 2, 1}, // pads
1); // count_include_pad
}
TEST(QLinearPoolTest, AveragePool3D_ExcludePadPixel_nhwc) {
RunQLinearAveragePoolNhwcU8(
{1, 1, 5, 7, 9}, // x shape