mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Rashuai/fix dilation (#415)
* test with conv * add dilation to shape inferencing * add test cases * add test cases
This commit is contained in:
parent
696ab8a194
commit
2062c49033
3 changed files with 147 additions and 6 deletions
|
|
@ -31,6 +31,7 @@ inline void ComputeTransposePadAndOutputShape(
|
|||
const int64_t in_size,
|
||||
const int64_t stride,
|
||||
const int64_t kernel,
|
||||
const int64_t dilation,
|
||||
const int64_t adj,
|
||||
AutoPadType pad_type,
|
||||
int64_t* pad_head,
|
||||
|
|
@ -39,7 +40,7 @@ inline void ComputeTransposePadAndOutputShape(
|
|||
if (*out_size != -1) {
|
||||
ORT_ENFORCE(*out_size >= 0);
|
||||
// total padding size
|
||||
int64_t paddings = std::max<int64_t>(0, (in_size - 1) * stride + kernel + adj - *out_size);
|
||||
int64_t paddings = std::max<int64_t>(0, (in_size - 1) * stride + kernel + dilation - 1 + adj - *out_size);
|
||||
if (pad_type == AutoPadType::SAME_UPPER) { // pad more on head when paddings are odd.
|
||||
*pad_head = paddings - paddings / 2;
|
||||
*pad_tail = paddings / 2;
|
||||
|
|
@ -61,14 +62,14 @@ inline void ComputeTransposePadAndOutputShape(
|
|||
case AutoPadType::SAME_LOWER:
|
||||
*pad_head = 0;
|
||||
*pad_tail = 0;
|
||||
*out_size = (in_size - 1) * stride + kernel + adj;
|
||||
*out_size = (in_size - 1) * stride + kernel + dilation - 1 + adj;
|
||||
break;
|
||||
default:
|
||||
throw NotImplementedException("pad type not supported");
|
||||
}
|
||||
} else {
|
||||
*out_size =
|
||||
(in_size - 1) * stride + kernel + adj - *pad_head - *pad_tail;
|
||||
(in_size - 1) * stride + kernel + dilation - 1 + adj - *pad_head - *pad_tail;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -144,7 +145,7 @@ Status ConvTransposeBase::PrepareForCompute(OpKernelContext* context, bool has_b
|
|||
|
||||
std::vector<int64_t> Y_dims;
|
||||
|
||||
ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, strides, output_padding, &pads, &Y_dims);
|
||||
ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, strides, dilations, output_padding, &pads, &Y_dims);
|
||||
TensorShape Yshape(Y_dims);
|
||||
Tensor* Y = context->Output(0, Yshape);
|
||||
|
||||
|
|
@ -169,6 +170,7 @@ void ConvTransposeBase::ComputePadsAndOutputShape(
|
|||
const int64_t output_channel,
|
||||
const std::vector<int64_t>& kernel_shape,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& dilations,
|
||||
const std::vector<int64_t>& output_padding,
|
||||
std::vector<int64_t>* pads,
|
||||
std::vector<int64_t>* output_shape) const {
|
||||
|
|
@ -189,6 +191,7 @@ void ConvTransposeBase::ComputePadsAndOutputShape(
|
|||
H,
|
||||
strides[0],
|
||||
kernel_shape[0],
|
||||
dilations[0],
|
||||
output_padding[0],
|
||||
auto_pad_,
|
||||
&pads->at(0),
|
||||
|
|
@ -199,6 +202,7 @@ void ConvTransposeBase::ComputePadsAndOutputShape(
|
|||
W,
|
||||
strides[1],
|
||||
kernel_shape[1],
|
||||
dilations[1],
|
||||
output_padding[1],
|
||||
auto_pad_,
|
||||
&pads->at(1),
|
||||
|
|
@ -256,8 +260,8 @@ Status ConvTranspose<T>::Compute(OpKernelContext* context) const {
|
|||
p.Y->Shape()[3],
|
||||
p.kernel_shape[0],
|
||||
p.kernel_shape[1],
|
||||
1,
|
||||
1,
|
||||
p.dilations[0],
|
||||
p.dilations[1],
|
||||
p.pads[0],
|
||||
p.pads[1],
|
||||
p.pads[2],
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class ConvTransposeBase : public ConvBase {
|
|||
const int64_t output_channel,
|
||||
const std::vector<int64_t>& kernel_shape,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& dilations,
|
||||
const std::vector<int64_t>& output_padding,
|
||||
std::vector<int64_t>* pads,
|
||||
std::vector<int64_t>* output_shape) const;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ struct ConvTransposeOpAttributes {
|
|||
vector<int64_t> output_shape;
|
||||
vector<int64_t> pads;
|
||||
vector<int64_t> strides;
|
||||
vector<int64_t> dilations;
|
||||
int64_t group;
|
||||
};
|
||||
|
||||
|
|
@ -35,6 +36,7 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes,
|
|||
}
|
||||
test.AddAttribute("pads", attributes.pads);
|
||||
test.AddAttribute("strides", attributes.strides);
|
||||
test.AddAttribute("dilations", attributes.dilations);
|
||||
test.AddAttribute("group", attributes.group);
|
||||
|
||||
ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs");
|
||||
|
|
@ -54,6 +56,7 @@ TEST(ConvTransposeTest, ConvTranspose_1) {
|
|||
{}, // output_shape
|
||||
vector<int64_t>{1, 1, 1, 1}, // pads
|
||||
vector<int64_t>{2, 2}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0.16857791f, -0.15161794f, 0.08540368f,
|
||||
|
|
@ -81,6 +84,7 @@ TEST(ConvTransposeTest, ConvTranspose_Bias_1) {
|
|||
{}, // output_shape
|
||||
vector<int64_t>{1, 1, 1, 1}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0.22572887f, -0.07105902f, -0.40399021f, -0.14461157f, 0.05367219f,
|
||||
|
|
@ -111,6 +115,7 @@ TEST(ConvTransposeTest, ConvTranspose_Bias_2) {
|
|||
{}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0.01270282f, 0.09657472f, -0.36909008f, -0.08085269f,
|
||||
|
|
@ -158,6 +163,7 @@ TEST(ConvTransposeTest, ConvTranspose_Output_Shape) {
|
|||
vector<int64_t>{1, 3, 4, 4}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
int image_size = 4 * 4;
|
||||
|
|
@ -197,6 +203,7 @@ TEST(ConvTransposeTest, ConvTranspose_Output_Shape2) {
|
|||
vector<int64_t>{1, 1, 1, 14}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
|
||||
|
|
@ -217,6 +224,7 @@ TEST(ConvTransposeTest, ConvTranspose_Output_Shape_Batch) {
|
|||
vector<int64_t>{2, 1, 1, 14}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
|
||||
|
|
@ -239,6 +247,7 @@ TEST(ConvTransposeTest, ConvTranspose_Invalid_Kernel_Shape) {
|
|||
vector<int64_t>{2, 1, 1, 14}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
|
||||
|
|
@ -263,6 +272,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx) {
|
|||
{}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0., 1., 2., 3., 4., 5., 6., 7., 8.};
|
||||
|
|
@ -292,6 +302,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx2) {
|
|||
{}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1 // group
|
||||
};
|
||||
vector<float> X = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17.};
|
||||
|
|
@ -323,6 +334,7 @@ TEST(ConvTransposeTest, ConvTranspose_onnx_group) {
|
|||
{}, // output_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
4 // group
|
||||
};
|
||||
vector<float> X = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f};
|
||||
|
|
@ -334,5 +346,129 @@ TEST(ConvTransposeTest, ConvTranspose_onnx_group) {
|
|||
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
TEST(ConvTransposeTest, ConvTranspose2D_group_dilation2) {
|
||||
ConvTransposeOpAttributes attrs = {
|
||||
vector<int64_t>{2, 2},
|
||||
{}, {},
|
||||
vector<int64_t>{0,0,0,0},
|
||||
vector<int64_t>{1,1},
|
||||
{2,2},
|
||||
1
|
||||
};
|
||||
|
||||
vector<float> X = {11.0f,12.0f,21.0f,22.0f};
|
||||
vector<int64_t> X_shape = {1,1,2,2};
|
||||
vector<float> W = {1.0f,1.0f,1.0f,1.0f};
|
||||
vector<int64_t> W_shape = {1,1,2,2};
|
||||
vector<int64_t> Y_shape = {1,1,4,4};
|
||||
auto expected_vals = {11.0f,12.0f,11.0f,12.0f,
|
||||
21.0f,22.0f,21.0f,22.0f,
|
||||
11.0f,12.0f,11.0f,12.0f,
|
||||
21.0f,22.0f,21.0f,22.0f};
|
||||
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
TEST(ConvTransposeTest, ConvTranspose2D_group_dilation3) {
|
||||
ConvTransposeOpAttributes attrs = {
|
||||
vector<int64_t>{2, 2},
|
||||
{}, {},
|
||||
vector<int64_t>{0,0,0,0},
|
||||
vector<int64_t>{1,1},
|
||||
{3,3},
|
||||
1
|
||||
};
|
||||
|
||||
vector<float> X = {11.0f,12.0f,21.0f,22.0f};
|
||||
vector<int64_t> X_shape = {1,1,2,2};
|
||||
vector<float> W = {1.0f,1.0f,1.0f,1.0f};
|
||||
vector<int64_t> W_shape = {1,1,2,2};
|
||||
vector<int64_t> Y_shape = {1,1,5,5};
|
||||
auto expected_vals = {11.0f,12.0f,0.0f,11.0f,12.0f,
|
||||
21.0f,22.0f,0.0f,21.0f,22.0f,
|
||||
0.0f, 0.0f, 0.0f,0.0f, 0.0f,
|
||||
11.0f,12.0f,0.0f,11.0f,12.0f,
|
||||
21.0f,22.0f,0.0f,21.0f,22.0f};
|
||||
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
TEST(ConvTransposeTest, ConvTranspose3D_group_dilation2) {
|
||||
ConvTransposeOpAttributes attrs = {
|
||||
vector<int64_t>{2, 2},
|
||||
{}, {},
|
||||
vector<int64_t>{0,0,0,0,0,0,0,0,0},
|
||||
vector<int64_t>{1,1,1},
|
||||
{2,2},
|
||||
1
|
||||
};
|
||||
|
||||
vector<float> X = {3.0f,8.0f,1.0f,9.0f,5.0f,7.0f,3.0f,2.0f,6.0f};
|
||||
vector<int64_t> X_shape = {1,1,3,3};
|
||||
vector<float> W = {7.0f,2.0f,1.0f,9.0f};
|
||||
vector<int64_t> W_shape = {1,1,2,2};
|
||||
vector<int64_t> Y_shape = {1,1,5,5};
|
||||
auto expected_vals = {21.0f, 56.0f, 13.0f, 16.0f, 2.0f,
|
||||
63.0f, 35.0f, 67.0f, 10.0f, 14.0f,
|
||||
24.0f, 22.0f, 76.0f, 76.0f, 21.0f,
|
||||
9.0f, 5.0f, 88.0f, 45.0f, 63.0f,
|
||||
3.0f, 2.0f, 33.0f, 18.0f, 54.0f};
|
||||
|
||||
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
TEST(ConvTransposeTest, ConvTranspose3D_group_dilation3) {
|
||||
ConvTransposeOpAttributes attrs = {
|
||||
vector<int64_t>{2, 2},
|
||||
{}, {},
|
||||
vector<int64_t>{0,0,0,0,0,0,0,0,0},
|
||||
vector<int64_t>{1,1,1},
|
||||
{3,3},
|
||||
1
|
||||
};
|
||||
|
||||
vector<float> X = {3.0f,8.0f,1.0f,9.0f,5.0f,7.0f,3.0f,2.0f,6.0f};
|
||||
vector<int64_t> X_shape = {1,1,3,3};
|
||||
vector<float> W = {7.0f,2.0f,1.0f,9.0f};
|
||||
vector<int64_t> W_shape = {1,1,2,2};
|
||||
vector<int64_t> Y_shape = {1,1,6,6};
|
||||
auto expected_vals = {21.0f, 56.0f, 7.0f, 6.0f, 16.0f, 2.0f,
|
||||
63.0f, 35.0f, 49.0f, 18.0f, 10.0f, 14.0f,
|
||||
21.0f, 14.0f, 42.0f, 6.0f, 4.0f, 12.0f,
|
||||
3.0f, 8.0f, 1.0f, 27.0f, 72.0f, 9.0f,
|
||||
9.0f, 5.0f, 7.0f, 81.0f, 45.0f, 63.0f,
|
||||
3.0f, 2.0f, 6.0f, 27.0f, 18.0f, 54.0f};
|
||||
|
||||
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
TEST(ConvTransposeTest, ConvTranspose3D_group2_dilation2) {
|
||||
ConvTransposeOpAttributes attrs = {
|
||||
vector<int64_t>{2, 2},
|
||||
{}, {},
|
||||
vector<int64_t>{0,0,0,0,0,0,0,0,0},
|
||||
vector<int64_t>{1,1,1},
|
||||
{2,2},
|
||||
2
|
||||
};
|
||||
|
||||
vector<float> X = {3.0f,8.0f,1.0f,9.0f,5.0f,7.0f,3.0f,2.0f,3.0f,7.0f,9.0f,1.0f,5.0f,2.0f,3.0f,9.0f,0.0f,2.0f};
|
||||
vector<int64_t> X_shape = {1,2,3,3};
|
||||
vector<float> W = {9.0f,3.0f,1.0f,2.0f,3.0f,7.0f,0.0f,8.0f};
|
||||
vector<int64_t> W_shape = {2,1,2,2};
|
||||
vector<int64_t> Y_shape = {1,2,5,5};
|
||||
auto expected_vals = {27.0f, 72.0f, 18.0f, 24.0f, 3.0f,
|
||||
81.0f, 45.0f, 90.0f, 15.0f, 21.0f,
|
||||
30.0f, 26.0f, 43.0f, 22.0f, 11.0f,
|
||||
9.0f, 5.0f, 25.0f, 10.0f, 14.0f,
|
||||
3.0f, 2.0f, 9.0f, 4.0f, 6.0f,
|
||||
21.0f, 27.0f, 52.0f, 63.0f, 7.0f,
|
||||
15.0f, 6.0f, 44.0f, 14.0f, 21.0f,
|
||||
27.0f, 0.0f, 125.0f, 72.0f, 22.0f,
|
||||
0.0f, 0.0f, 40.0f, 16.0f, 24.0f,
|
||||
0.0f, 0.0f, 72.0f, 0.0f, 16.0f};
|
||||
|
||||
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
|
||||
}
|
||||
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue