mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Add shape inference logic for Crop (contrib) op (#1157)
* Add shape inference logic for Crop contrib op * Fix build break * More refinements * PR feedback * PR feedback 2
This commit is contained in:
parent
6c17567d7b
commit
fc9a895b46
3 changed files with 140 additions and 26 deletions
|
|
@ -19,16 +19,16 @@ class CropBase {
|
|||
}
|
||||
|
||||
Status ValidateInput(const Tensor* X) const {
|
||||
if (border_.size() < 4) {
|
||||
if (border_.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Attribute border needs to be specified with four border elements, got ", border_.size());
|
||||
}
|
||||
|
||||
const auto dims = X->Shape().GetDims();
|
||||
|
||||
if (dims.size() < 4) {
|
||||
if (dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input is expected to have four dimensions corresponding to [N,C,H,W], got ", dims.size());
|
||||
"Input is expected to have four dimensions corresponding to [N,C,H,W], got ", dims.size(), " input dimensions instead");
|
||||
}
|
||||
|
||||
const int64_t H = dims[2];
|
||||
|
|
@ -41,11 +41,11 @@ class CropBase {
|
|||
bottomBorder = border_[3];
|
||||
|
||||
if (H < topBorder + bottomBorder) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's height (", H, ") needs to be greater than the topBorder (", topBorder, ") + bottomBorder (", bottomBorder, ")");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's height (", H, ") needs to be greater than or equal to the topBorder (", topBorder, ") + bottomBorder (", bottomBorder, ")");
|
||||
}
|
||||
|
||||
if (W < leftBorder + rightBorder) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than the leftBorder (", leftBorder, ") + rightBorder (", rightBorder, ")");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than or equal to the leftBorder (", leftBorder, ") + rightBorder (", rightBorder, ")");
|
||||
}
|
||||
|
||||
int64_t bottomLimit = H - bottomBorder;
|
||||
|
|
@ -58,11 +58,11 @@ class CropBase {
|
|||
|
||||
if (H < bottomLimit) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input's height (", H, ") needs to be greater than the topBorder (", topBorder, ") + scale_[0] (", scale_[0], ")");
|
||||
"Input's height (", H, ") needs to be greater than or equal to the topBorder (", topBorder, ") + scale_[0] (", scale_[0], ")");
|
||||
}
|
||||
|
||||
if (W < rightLimit) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than the leftBorder (", leftBorder, ") + scale_[1] (", scale_[1], ")");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input's width (", W, ") needs to be greater than or equal to the leftBorder (", leftBorder, ") + scale_[1] (", scale_[1], ")");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -132,5 +132,5 @@ class Crop final : public CropBase, public OpKernel {
|
|||
}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace contrib
|
||||
} //namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -437,11 +437,102 @@ and op)DOC";
|
|||
.SinceVersion(10)
|
||||
.Deprecate()
|
||||
.SetDoc(Crop_ver1_doc)
|
||||
.Attr("border", "A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).", AttributeProto::INTS, OPTIONAL)
|
||||
.Attr("border", "A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).", AttributeProto::INTS)
|
||||
.Attr("scale", "A 1-D values of (height, width).", AttributeProto::INTS, OPTIONAL)
|
||||
.Input(0, "input", "Input tensor of shape [N,C,H,W]", "T")
|
||||
.Output(0, "output", "Result, has same type as input, with H and W dimensions reduced.", "T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.");
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
  // Type and shape inference for the Crop contrib op.
  //
  // The output element type is propagated from input 0. When the input shape
  // is known it must be 4-D ([N,C,H,W]); N and C are copied to the output and
  // the cropped H/W are computed from the 'border' attribute and, when given,
  // the 'scale' attribute. When the input shape is unknown only the output
  // rank (4) is recorded.
  //
  // Fails shape inference when: the input is not 4-D, 'border' is missing or
  // does not hold 4 values, 'scale' is present but does not hold 2 values, or
  // the requested crop does not fit inside the (statically known) H/W.

  // Type inference
  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);

  // Shape inference
  auto* output_shape =
      ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();

  if (ONNX_NAMESPACE::hasNInputShapes(ctx, 1)) {
    const auto& input_shape =
        ctx.getInputType(0)->tensor_type().shape();
    const auto input_rank =
        input_shape.dim_size();
    if (input_rank != 4)
      fail_shape_inference("Input's shape must be 4-D");

    // parse necessary attributes for further processing
    std::vector<int64_t> border;
    bool border_present =
        getRepeatedAttribute(ctx, "border", border);
    if (!border_present || border.size() != 4)
      fail_shape_inference(
          "'Border' attribute must be present and must contain exactly 4 values - "
          "(left_border, top_border, right_border, bottom_border)");

    std::vector<int64_t> scale;
    bool scale_present =
        getRepeatedAttribute(ctx, "scale", scale);
    if (scale_present && scale.size() != 2)
      fail_shape_inference("'Scale' must contain exactly 2 values - (height, width)");

    // actual shape inference processing
    // [N, C] can be copied over from the input as is.
    // The dims are appended with add_dim(): the output shape has no dims yet
    // at this point (the rank-only branch below appends all four the same
    // way), so indexing it with mutable_dim(i) would be out of range.
    *output_shape->add_dim() = input_shape.dim(static_cast<int>(0));
    *output_shape->add_dim() = input_shape.dim(static_cast<int>(1));

    // process 'H' and 'W'
    if (!input_shape.dim(static_cast<int>(2)).has_dim_value() ||
        !input_shape.dim(static_cast<int>(3)).has_dim_value()) {
      // either the height or the width of the input is symbolic, so the
      // cropped sizes can't be computed; add two dims as placeholders for
      // output_H and output_W and return
      output_shape->add_dim();
      output_shape->add_dim();
      return;
    }

    int64_t H = input_shape.dim(static_cast<int>(2)).dim_value();
    int64_t W = input_shape.dim(static_cast<int>(3)).dim_value();

    int64_t left_border = border[0],
            top_border = border[1],
            right_border = border[2],
            bottom_border = border[3];

    if (H < top_border + bottom_border)
      fail_shape_inference("Input's height (", H, ") needs to be greater than or equal to "
                           "the top_border (", top_border, ") + bottom_border (", bottom_border, ")");

    if (W < left_border + right_border)
      fail_shape_inference("Input's width (", W, ") needs to be greater than or equal to "
                           "the left_border (", left_border, ") + right_border (", right_border, ")");

    // Without 'scale' the crop window ends where the bottom/right borders begin.
    int64_t bottom_limit = H - bottom_border;
    int64_t right_limit = W - right_border;

    // scale = (height, width)
    // When given, 'scale' replaces the bottom/right borders: the window then
    // extends scale[0] rows / scale[1] columns from the top/left borders.
    if (!scale.empty()) {
      bottom_limit = top_border + scale[0];
      right_limit = left_border + scale[1];

      if (H < bottom_limit)
        fail_shape_inference("Input's height (", H, ") needs to be greater than or equal to the top_border (", top_border, ") + scale[0] (", scale[0], ")");

      if (W < right_limit)
        fail_shape_inference("Input's width (", W, ") needs to be greater than or equal to the left_border (", left_border, ") + scale[1] (", scale[1], ")");
    }

    auto* h_output_dim = output_shape->add_dim();
    h_output_dim->set_dim_value(bottom_limit - top_border);

    auto* w_output_dim = output_shape->add_dim();
    w_output_dim->set_dim_value(right_limit - left_border);

  } else {
    // Rank Inference at the very least
    // (We know that the output is going to be 4-D)
    for (int i = 0; i < 4; ++i) {
      output_shape->add_dim();
    }
  }
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(DynamicSlice)
|
||||
.SinceVersion(10)
|
||||
|
|
@ -791,41 +882,30 @@ activation and leaky_relu_alpha.)DOC")
|
|||
The first mode is selected when "tokenexp" is not set and "separators" is set. If "tokenexp" is set and "separators" is not set,
|
||||
the second mode will be used. The first mode breaks each input string into tokens by matching and removing separators.
|
||||
"separators" is a list of strings which are regular expressions. "tokenexp" is a single regular expression.
|
||||
|
||||
Let's assume "separators" is [" "] and consider an example.
|
||||
If input is
|
||||
|
||||
["Hello World", "I love computer science !"] whose shape is [2],
|
||||
|
||||
then the output would be
|
||||
|
||||
[["Hello", "World", padvalue, padvalue, padvalue],
|
||||
["I", "love", "computer", "science", "!"]]
|
||||
|
||||
whose shape is [2, 5] because you can find at most 5 tokens per input string.
|
||||
Note that the input at most can have two axes, so 3-D and higher dimension are not supported.
|
||||
|
||||
If "separators" contains a single empty string, the Tokenizer will enter into character tokenezation mode. This means all strings
|
||||
will be broken part into individual characters.
|
||||
|
||||
For each input string, the second mode searches matches of "tokenexp" and each match will be a token in Y.
|
||||
The matching of "tokenexp" is conducted greedily (i.e., a match should be as long as possible).
|
||||
This operator searches for the first match starting from the beginning of the considered string,
|
||||
and then launches another search starting from the first remained character after the first matched token.
|
||||
If no match found, this operator will remove the first character from the remained string and do another search.
|
||||
This procedure will be repeated until reaching the end of the considered string.
|
||||
|
||||
Let's consider another example to illustrate the effect of setting "mark" to true.
|
||||
If input is ["Hello", "World"],
|
||||
then the corresponding output would be [0x02, "Hello", "World", 0x03].
|
||||
This implies that if mark is true, [C]/[N, C] - input's output shape becomes [C, D+2]/[N, C, D+2].
|
||||
|
||||
If tokenizer removes the entire content of [C]-input, it will produce [[]].
|
||||
I.e. the output shape should be [C][0] or [N][C][0] if input shape was [N][C].
|
||||
|
||||
If the tokenizer receives empty input of [0] then the output is [0] if empty input
|
||||
of [N, 0] then [N, 0].
|
||||
|
||||
)DOC";
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Tokenizer)
|
||||
|
|
@ -1117,7 +1197,7 @@ Example 4:
|
|||
if ((pads_initializer->dims_size() != 1 &&
|
||||
pads_initializer->dims_size() != 2) ||
|
||||
(pads_initializer->dims_size() == 2 &&
|
||||
pads_shape.dim((int)0).dim_value() != 1) ||
|
||||
pads_shape.dim(static_cast<int>(0)).dim_value() != 1) ||
|
||||
pads_initializer->data_type() != ONNX_NAMESPACE::TensorProto::INT64)
|
||||
fail_shape_inference(
|
||||
"'pads' input must be a 1D (shape: [input_rank]) "
|
||||
|
|
@ -1140,8 +1220,8 @@ Example 4:
|
|||
|
||||
const auto& output_shape =
|
||||
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
|
||||
for (size_t i = 0; (int64_t)i < input_rank; ++i) {
|
||||
const auto& input_dim = input_shape.dim((int)i);
|
||||
for (size_t i = 0; static_cast<int64_t>(i) < input_rank; ++i) {
|
||||
const auto& input_dim = input_shape.dim(static_cast<int>(i));
|
||||
auto* output_dim = output_shape->add_dim();
|
||||
if (input_dim.has_dim_value()) {
|
||||
output_dim->set_dim_value(
|
||||
|
|
@ -1153,7 +1233,7 @@ Example 4:
|
|||
} else {
|
||||
// Infer output shapes' rank in any case
|
||||
auto* output_shape_0 = getOutputShape(ctx, 0);
|
||||
for (size_t i = 0; (int64_t)i < input_rank; ++i) {
|
||||
for (size_t i = 0; static_cast<int64_t>(i) < input_rank; ++i) {
|
||||
output_shape_0->add_dim();
|
||||
}
|
||||
}
|
||||
|
|
@ -1249,4 +1329,4 @@ Example 4:
|
|||
#endif
|
||||
} // namespace contrib
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
34
onnxruntime/test/contrib_ops/crop_op_test.cc
Normal file
34
onnxruntime/test/contrib_ops/crop_op_test.cc
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Crops a 4x4 single-channel input using only the 'border' attribute and
// expects the 2x2 result {6, 7, 10, 11}.
TEST(CropOpTest, Crop_Border) {
  OpTester tester("Crop", 1, onnxruntime::kOnnxDomain);

  // Input of shape {1, 1, 4, 4} holding the values 1..16.
  tester.AddInput<float>("x", {1, 1, 4, 4},
                         {1.0, 2.0, 3.0, 4.0,
                          5.0, 6.0, 7.0, 8.0,
                          9.0, 10.0, 11.0, 12.0,
                          13.0, 14.0, 15.0, 16.0});

  // One element for each of the four border entries.
  const std::vector<int64_t> crop_border{1, 1, 1, 1};
  tester.AddAttribute("border", crop_border);

  tester.AddOutput<float>("y", {1, 1, 2, 2},
                          {6.0, 7.0,
                           10.0, 11.0});

  // NOTE(review): the last argument is presumably the set of excluded
  // execution providers - confirm against OpTester::Run.
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
// Crops a 4x4 single-channel input with both 'border' and 'scale' set and
// expects the 2x2 result {6, 7, 10, 11}.
TEST(CropOpTest, Crop_Scale) {
  OpTester tester("Crop", 1, onnxruntime::kOnnxDomain);

  // Input of shape {1, 1, 4, 4} holding the values 1..16.
  tester.AddInput<float>("x", {1, 1, 4, 4},
                         {1.0, 2.0, 3.0, 4.0,
                          5.0, 6.0, 7.0, 8.0,
                          9.0, 10.0, 11.0, 12.0,
                          13.0, 14.0, 15.0, 16.0});

  // One element for each of the four border entries.
  const std::vector<int64_t> crop_border{1, 1, 1, 1};
  tester.AddAttribute("border", crop_border);

  // Two-element 'scale' attribute; the expected output below is 2x2.
  const std::vector<int64_t> crop_scale{2, 2};
  tester.AddAttribute("scale", crop_scale);

  tester.AddOutput<float>("y", {1, 1, 2, 2},
                          {6.0, 7.0,
                           10.0, 11.0});

  // NOTE(review): the last argument is presumably the set of excluded
  // execution providers - confirm against OpTester::Run.
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue