From fc9a895b46fa1730dedf6dcbb13e8b6ede20960e Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Thu, 13 Jun 2019 12:45:09 -0700 Subject: [PATCH] Add shape inference logic for Crop (contrib) op (#1157) * Add shape inference logic for Crop contrib op * Fix build break * More refinements * PR feedback * PR feedback 2 --- onnxruntime/contrib_ops/cpu/crop.h | 16 +-- .../core/graph/contrib_ops/contrib_defs.cc | 116 +++++++++++++++--- onnxruntime/test/contrib_ops/crop_op_test.cc | 34 +++++ 3 files changed, 140 insertions(+), 26 deletions(-) create mode 100644 onnxruntime/test/contrib_ops/crop_op_test.cc diff --git a/onnxruntime/contrib_ops/cpu/crop.h b/onnxruntime/contrib_ops/cpu/crop.h index 16f397c717..35bc94e39d 100644 --- a/onnxruntime/contrib_ops/cpu/crop.h +++ b/onnxruntime/contrib_ops/cpu/crop.h @@ -19,16 +19,16 @@ class CropBase { } Status ValidateInput(const Tensor* X) const { - if (border_.size() < 4) { + if (border_.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attribute border needs to be specified with four border elements, got ", border_.size()); } const auto dims = X->Shape().GetDims(); - if (dims.size() < 4) { + if (dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input is expected to have four dimensions corresponding to [N,C,H,W], got ", dims.size()); + "Input is expected to have four dimensions corresponding to [N,C,H,W], got ", dims.size(), " input dimensions instead"); } const int64_t H = dims[2]; @@ -41,11 +41,11 @@ class CropBase { bottomBorder = border_[3]; if (H < topBorder + bottomBorder) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's height (", H, ") needs to be greater than the topBorder (", topBorder, ") + bottomBorder (", bottomBorder, ")"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's height (", H, ") needs to be greater than or equal to the topBorder (", topBorder, ") + bottomBorder (", bottomBorder, ")"); } if (W < leftBorder + rightBorder) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than the leftBorder (", leftBorder, ") + rightBorder (", rightBorder, ")"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than or equal to the leftBorder (", leftBorder, ") + rightBorder (", rightBorder, ")"); } int64_t bottomLimit = H - bottomBorder; @@ -58,11 +58,11 @@ class CropBase { if (H < bottomLimit) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input's height (", H, ") needs to be greater than the topBorder (", topBorder, ") + scale_[0] (", scale_[0], ")"); + "Input's height (", H, ") needs to be greater than or equal to the topBorder (", topBorder, ") + scale_[0] (", scale_[0], ")"); } if (W < rightLimit) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than the leftBorder (", leftBorder, ") + scale_[1] (", scale_[1], ")"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than or equal to the leftBorder (", leftBorder, ") + scale_[1] (", scale_[1], ")"); } } @@ -132,5 +132,5 @@ class Crop final : public CropBase, public OpKernel { } }; -} +} // namespace contrib } //namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 13e4cea71b..9e7a9b1ced 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -437,11 +437,102 @@ and op)DOC"; .SinceVersion(10) .Deprecate() .SetDoc(Crop_ver1_doc) - .Attr("border", "A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).", AttributeProto::INTS, OPTIONAL) + .Attr("border", "A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).", AttributeProto::INTS) .Attr("scale", "A 1-D values of (height, width).", AttributeProto::INTS, OPTIONAL) .Input(0, "input", "Input tensor of shape [N,C,H,W]", "T") .Output(0, "output", "Result, has same type as input, with H and W dimensions reduced.", "T") - .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors."); + .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // Shape inference + auto* output_shape = + ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + if (ONNX_NAMESPACE::hasNInputShapes(ctx, 1)) { + const auto& input_shape = + ctx.getInputType(0)->tensor_type().shape(); + const auto input_rank = + input_shape.dim_size(); + if (input_rank != 4) + fail_shape_inference("Input's shape must be 4-D"); + + // parse necessary attributes for futher processing + std::vector border; + bool border_present = + getRepeatedAttribute(ctx, "border", border); + if (!border_present || border.size() != 4) + fail_shape_inference( + "'Border' attribute must be present and must contain exactly 4 values - " + "(left_border, top_border, right_border, bottom_border)"); + + std::vector scale; + bool scale_present = + getRepeatedAttribute(ctx, "scale", scale); + if (scale_present && scale.size() != 2) + fail_shape_inference("'Scale' must contain exactly 2 values - (height, width)"); + + // actual shape inference processing + // [N, C] can be copied over from the input as is + *output_shape->mutable_dim(static_cast(0)) = input_shape.dim(static_cast(0)); + *output_shape->mutable_dim(static_cast(1)) = input_shape.dim(static_cast(1)); + + // process 'H' and 'W' + if (!input_shape.dim(static_cast(2)).has_dim_value() || + !input_shape.dim(static_cast(3)).has_dim_value()) { + // either height and width input has symbolic dims, so can't proceed further + // add two dims as placeholders for output_H and output_W and return + output_shape->add_dim(); + output_shape->add_dim(); + return; + } + + int64_t H = input_shape.dim(static_cast(2)).dim_value(); + int64_t W = input_shape.dim(static_cast(3)).dim_value(); + + int64_t left_border = border[0], + top_border = border[1], + right_border = border[2], + bottom_border = border[3]; + + if (H < top_border + bottom_border) + fail_shape_inference("Input's height (", H, ") needs to be greater than or equal to " + "the top_border (", top_border, ") + bottom_border (", bottom_border, ")"); + + if (W < left_border + right_border) + fail_shape_inference("Input's width (", W, ") needs to be greater than or equal to " + "the left_border (", left_border, ") + right_border (", right_border, ")"); + + int64_t bottom_limit = H - bottom_border; + int64_t right_limit = W - right_border; + + // scale = (height, width) + if (!scale.empty()) { + bottom_limit = top_border + scale[0]; + right_limit = left_border + scale[1]; + + if (H < bottom_limit) + fail_shape_inference("Input's height (", H, ") needs to be greater than or equal to the top_border (", top_border, ") + scale[0] (", scale[0], ")"); + + if (W < right_limit) + fail_shape_inference("Input's width (", W, ") needs to be greater than or equal to the left_border (", left_border, ") + scale[1] (", scale[1], ")"); + } + + auto* h_output_dim = output_shape->add_dim(); + h_output_dim->set_dim_value(bottom_limit - top_border); + + auto* w_output_dim = output_shape->add_dim(); + w_output_dim->set_dim_value(right_limit - left_border); + + } else { + // Rank Inference at the very least + // (We know that the output is going to be 4-D) + for (int i = 0; i < 4; ++i) { + output_shape->add_dim(); + } + } + }); ONNX_CONTRIB_OPERATOR_SCHEMA(DynamicSlice) .SinceVersion(10) @@ -791,41 +882,30 @@ activation and leaky_relu_alpha.)DOC") The first mode is selected when "tokenexp" is not set and "separators" is set. If "tokenexp" is set and "separators" is not set, the second mode will be used. The first mode breaks each input string into tokens by matching and removing separators. "separators" is a list of strings which are regular expressions. "tokenexp" is a single regular expression. - Let's assume "separators" is [" "] and consider an example. If input is - ["Hello World", "I love computer science !"] whose shape is [2], - then the output would be - [["Hello", "World", padvalue, padvalue, padvalue], ["I", "love", "computer", "science", "!"]] - whose shape is [2, 5] because you can find at most 5 tokens per input string. Note that the input at most can have two axes, so 3-D and higher dimension are not supported. - If "separators" contains a single empty string, the Tokenizer will enter into character tokenezation mode. This means all strings will be broken part into individual characters. - For each input string, the second mode searches matches of "tokenexp" and each match will be a token in Y. The matching of "tokenexp" is conducted greedily (i.e., a match should be as long as possible). This operator searches for the first match starting from the beginning of the considered string, and then launches another search starting from the first remained character after the first matched token. If no match found, this operator will remove the first character from the remained string and do another search. This procedure will be repeated until reaching the end of the considered string. - Let's consider another example to illustrate the effect of setting "mark" to true. If input is ["Hello", "World"], then the corresponding output would be [0x02, "Hello", "World", 0x03]. This implies that if mark is true, [C]/[N, C] - input's output shape becomes [C, D+2]/[N, C, D+2]. - If tokenizer removes the entire content of [C]-input, it will produce [[]]. I.e. the output shape should be [C][0] or [N][C][0] if input shape was [N][C]. - If the tokenizer receives empty input of [0] then the output is [0] if empty input of [N, 0] then [N, 0]. - )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(Tokenizer) @@ -1117,7 +1197,7 @@ Example 4: if ((pads_initializer->dims_size() != 1 && pads_initializer->dims_size() != 2) || (pads_initializer->dims_size() == 2 && - pads_shape.dim((int)0).dim_value() != 1) || + pads_shape.dim(static_cast(0)).dim_value() != 1) || pads_initializer->data_type() != ONNX_NAMESPACE::TensorProto::INT64) fail_shape_inference( "'pads' input must be a 1D (shape: [input_rank]) " @@ -1140,8 +1220,8 @@ Example 4: const auto& output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); - for (size_t i = 0; (int64_t)i < input_rank; ++i) { - const auto& input_dim = input_shape.dim((int)i); + for (size_t i = 0; static_cast(i) < input_rank; ++i) { + const auto& input_dim = input_shape.dim(static_cast(i)); auto* output_dim = output_shape->add_dim(); if (input_dim.has_dim_value()) { output_dim->set_dim_value( @@ -1153,7 +1233,7 @@ Example 4: } else { // Infer ouput shapes' rank in any case auto* output_shape_0 = getOutputShape(ctx, 0); - for (size_t i = 0; (int64_t)i < input_rank; ++i) { + for (size_t i = 0; static_cast(i) < input_rank; ++i) { output_shape_0->add_dim(); } } @@ -1249,4 +1329,4 @@ Example 4: #endif } // namespace contrib } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/contrib_ops/crop_op_test.cc b/onnxruntime/test/contrib_ops/crop_op_test.cc new file mode 100644 index 0000000000..2ff4ceb7a6 --- /dev/null +++ b/onnxruntime/test/contrib_ops/crop_op_test.cc @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(CropOpTest, Crop_Border) { + OpTester test("Crop", 1, onnxruntime::kOnnxDomain); + test.AddInput("x", {1, 1, 4, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}); + std::vector border{1, 1, 1, 1}; + test.AddAttribute("border", border); + test.AddOutput("y", {1, 1, 2, 2}, {6.0, 7.0, 10.0, 11.0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(CropOpTest, Crop_Scale) { + OpTester test("Crop", 1, onnxruntime::kOnnxDomain); + test.AddInput("x", {1, 1, 4, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0}); + + std::vector border{1, 1, 1, 1}; + test.AddAttribute("border", border); + + std::vector scale{2, 2}; + test.AddAttribute("scale", scale); + + test.AddOutput("y", {1, 1, 2, 2}, {6.0, 7.0, 10.0, 11.0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +} // namespace test +} // namespace onnxruntime