diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc index ef8dadad16..9c1512854d 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.cc @@ -31,6 +31,7 @@ inline void ComputeTransposePadAndOutputShape( const int64_t in_size, const int64_t stride, const int64_t kernel, + const int64_t dilation, const int64_t adj, AutoPadType pad_type, int64_t* pad_head, @@ -39,7 +40,7 @@ inline void ComputeTransposePadAndOutputShape( if (*out_size != -1) { ORT_ENFORCE(*out_size >= 0); // total padding size - int64_t paddings = std::max(0, (in_size - 1) * stride + kernel + adj - *out_size); + int64_t paddings = std::max(0, (in_size - 1) * stride + kernel + dilation - 1 + adj - *out_size); if (pad_type == AutoPadType::SAME_UPPER) { // pad more on head when paddings are odd. *pad_head = paddings - paddings / 2; *pad_tail = paddings / 2; @@ -61,14 +62,14 @@ inline void ComputeTransposePadAndOutputShape( case AutoPadType::SAME_LOWER: *pad_head = 0; *pad_tail = 0; - *out_size = (in_size - 1) * stride + kernel + adj; + *out_size = (in_size - 1) * stride + kernel + dilation - 1 + adj; break; default: throw NotImplementedException("pad type not supported"); } } else { *out_size = - (in_size - 1) * stride + kernel + adj - *pad_head - *pad_tail; + (in_size - 1) * stride + kernel + dilation - 1 + adj - *pad_head - *pad_tail; } } } @@ -144,7 +145,7 @@ Status ConvTransposeBase::PrepareForCompute(OpKernelContext* context, bool has_b std::vector Y_dims; - ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, strides, output_padding, &pads, &Y_dims); + ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, strides, dilations, output_padding, &pads, &Y_dims); TensorShape Yshape(Y_dims); Tensor* Y = context->Output(0, Yshape); @@ -169,6 +170,7 @@ void ConvTransposeBase::ComputePadsAndOutputShape( const int64_t output_channel, const std::vector& kernel_shape, const std::vector& strides, + const std::vector& dilations, const std::vector& output_padding, std::vector* pads, std::vector* output_shape) const { @@ -189,6 +191,7 @@ void ConvTransposeBase::ComputePadsAndOutputShape( H, strides[0], kernel_shape[0], + dilations[0], output_padding[0], auto_pad_, &pads->at(0), @@ -199,6 +202,7 @@ void ConvTransposeBase::ComputePadsAndOutputShape( W, strides[1], kernel_shape[1], + dilations[1], output_padding[1], auto_pad_, &pads->at(1), @@ -256,8 +260,8 @@ Status ConvTranspose::Compute(OpKernelContext* context) const { p.Y->Shape()[3], p.kernel_shape[0], p.kernel_shape[1], - 1, - 1, + p.dilations[0], + p.dilations[1], p.pads[0], p.pads[1], p.pads[2], diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index 0e50c0fb12..2c0958a638 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -52,6 +52,7 @@ class ConvTransposeBase : public ConvBase { const int64_t output_channel, const std::vector& kernel_shape, const std::vector& strides, + const std::vector& dilations, const std::vector& output_padding, std::vector* pads, std::vector* output_shape) const; diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index a1c20488c3..bd2ef0cafa 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -15,6 +15,7 @@ struct ConvTransposeOpAttributes { vector output_shape; vector pads; vector strides; + vector dilations; int64_t group; }; @@ -35,6 +36,7 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, } test.AddAttribute("pads", attributes.pads); test.AddAttribute("strides", attributes.strides); + test.AddAttribute("dilations", attributes.dilations); test.AddAttribute("group", attributes.group); ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); @@ -54,6 +56,7 @@ TEST(ConvTransposeTest, ConvTranspose_1) { {}, // output_shape vector{1, 1, 1, 1}, // pads vector{2, 2}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0.16857791f, -0.15161794f, 0.08540368f, @@ -81,6 +84,7 @@ TEST(ConvTransposeTest, ConvTranspose_Bias_1) { {}, // output_shape vector{1, 1, 1, 1}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0.22572887f, -0.07105902f, -0.40399021f, -0.14461157f, 0.05367219f, @@ -111,6 +115,7 @@ TEST(ConvTransposeTest, ConvTranspose_Bias_2) { {}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0.01270282f, 0.09657472f, -0.36909008f, -0.08085269f, @@ -158,6 +163,7 @@ TEST(ConvTransposeTest, ConvTranspose_Output_Shape) { vector{1, 3, 4, 4}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; int image_size = 4 * 4; @@ -197,6 +203,7 @@ TEST(ConvTransposeTest, ConvTranspose_Output_Shape2) { vector{1, 1, 1, 14}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; @@ -217,6 +224,7 @@ TEST(ConvTransposeTest, ConvTranspose_Output_Shape_Batch) { vector{2, 1, 1, 14}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, @@ -239,6 +247,7 @@ TEST(ConvTransposeTest, ConvTranspose_Invalid_Kernel_Shape) { vector{2, 1, 1, 14}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, @@ -263,6 +272,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx) { {}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0., 1., 2., 3., 4., 5., 6., 7., 8.}; @@ -292,6 +302,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx2) { {}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 1 // group }; vector X = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17.}; @@ -323,6 +334,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx_group) { {}, // output_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides + vector{1, 1}, // dilations 4 // group }; vector X = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; @@ -334,5 +346,129 @@ TEST(ConvTransposeTest, ConvTranspose_onnx_group) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); } +TEST(ConvTransposeTest, ConvTranspose2D_group_dilation2) { + ConvTransposeOpAttributes attrs = { + vector{2, 2}, + {}, {}, + vector{0,0,0,0}, + vector{1,1}, + {2,2}, + 1 + }; + + vector X = {11.0f,12.0f,21.0f,22.0f}; + vector X_shape = {1,1,2,2}; + vector W = {1.0f,1.0f,1.0f,1.0f}; + vector W_shape = {1,1,2,2}; + vector Y_shape = {1,1,4,4}; + auto expected_vals = {11.0f,12.0f,11.0f,12.0f, + 21.0f,22.0f,21.0f,22.0f, + 11.0f,12.0f,11.0f,12.0f, + 21.0f,22.0f,21.0f,22.0f}; + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +TEST(ConvTransposeTest, ConvTranspose2D_group_dilation3) { + ConvTransposeOpAttributes attrs = { + vector{2, 2}, + {}, {}, + vector{0,0,0,0}, + vector{1,1}, + {3,3}, + 1 + }; + + vector X = {11.0f,12.0f,21.0f,22.0f}; + vector X_shape = {1,1,2,2}; + vector W = {1.0f,1.0f,1.0f,1.0f}; + vector W_shape = {1,1,2,2}; + vector Y_shape = {1,1,5,5}; + auto expected_vals = {11.0f,12.0f,0.0f,11.0f,12.0f, + 21.0f,22.0f,0.0f,21.0f,22.0f, + 0.0f, 0.0f, 0.0f,0.0f, 0.0f, + 11.0f,12.0f,0.0f,11.0f,12.0f, + 21.0f,22.0f,0.0f,21.0f,22.0f}; + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +TEST(ConvTransposeTest, ConvTranspose3D_group_dilation2) { + ConvTransposeOpAttributes attrs = { + vector{2, 2}, + {}, {}, + vector{0,0,0,0,0,0,0,0,0}, + vector{1,1,1}, + {2,2}, + 1 + }; + + vector X = {3.0f,8.0f,1.0f,9.0f,5.0f,7.0f,3.0f,2.0f,6.0f}; + vector X_shape = {1,1,3,3}; + vector W = {7.0f,2.0f,1.0f,9.0f}; + vector W_shape = {1,1,2,2}; + vector Y_shape = {1,1,5,5}; + auto expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 67.0f, 10.0f, 14.0f, + 24.0f, 22.0f, 76.0f, 76.0f, 21.0f, + 9.0f, 5.0f, 88.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 33.0f, 18.0f, 54.0f}; + + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +TEST(ConvTransposeTest, ConvTranspose3D_group_dilation3) { + ConvTransposeOpAttributes attrs = { + vector{2, 2}, + {}, {}, + vector{0,0,0,0,0,0,0,0,0}, + vector{1,1,1}, + {3,3}, + 1 + }; + + vector X = {3.0f,8.0f,1.0f,9.0f,5.0f,7.0f,3.0f,2.0f,6.0f}; + vector X_shape = {1,1,3,3}; + vector W = {7.0f,2.0f,1.0f,9.0f}; + vector W_shape = {1,1,2,2}; + vector Y_shape = {1,1,6,6}; + auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f, + 63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f, + 21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f, + 3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f, + 9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f, + 3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f}; + + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + +TEST(ConvTransposeTest, ConvTranspose3D_group2_dilation2) { + ConvTransposeOpAttributes attrs = { + vector{2, 2}, + {}, {}, + vector{0,0,0,0,0,0,0,0,0}, + vector{1,1,1}, + {2,2}, + 2 + }; + + vector X = {3.0f,8.0f,1.0f,9.0f,5.0f,7.0f,3.0f,2.0f,3.0f,7.0f,9.0f,1.0f,5.0f,2.0f,3.0f,9.0f,0.0f,2.0f}; + vector X_shape = {1,2,3,3}; + vector W = {9.0f,3.0f,1.0f,2.0f,3.0f,7.0f,0.0f,8.0f}; + vector W_shape = {2,1,2,2}; + vector Y_shape = {1,2,5,5}; + auto expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f, + 81.0f, 45.0f, 90.0f, 15.0f, 21.0f, + 30.0f, 26.0f, 43.0f, 22.0f, 11.0f, + 9.0f, 5.0f, 25.0f, 10.0f, 14.0f, + 3.0f, 2.0f, 9.0f, 4.0f, 6.0f, + 21.0f, 27.0f, 52.0f, 63.0f, 7.0f, + 15.0f, 6.0f, 44.0f, 14.0f, 21.0f, + 27.0f, 0.0f, 125.0f, 72.0f, 22.0f, + 0.0f, 0.0f, 40.0f, 16.0f, 24.0f, + 0.0f, 0.0f, 72.0f, 0.0f, 16.0f}; + + TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); +} + + } // namespace test } // namespace onnxruntime