From e8d722003a2e97d1737dbbf555cd3586f70d87b2 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 22 Apr 2019 13:24:27 -0700 Subject: [PATCH] Move NMS to Onnx domain (#865) * move files * move files * Remove NonMaxSuppression from Contrib op, move it to Onnx domain, opset 10 * move NMS out of namespace contrib * update data type in UT * update to latest onnx * white list the node test for Mod which is not implemented yet --- cmake/external/onnx | 2 +- onnxruntime/contrib_ops/contrib_kernels.cc | 2 - .../core/graph/contrib_ops/contrib_defs.cc | 58 ----------- .../providers/cpu/cpu_execution_provider.cc | 2 + .../providers/cpu/nn}/non_max_suppression.cc | 28 +++--- .../providers/cpu/nn}/non_max_suppression.h | 14 ++- onnxruntime/test/onnx/main.cc | 6 +- .../cpu/nn}/non_max_suppression_test.cc | 96 +++++++++---------- .../test/python/onnx_backend_test_series.py | 4 + .../linux/docker/scripts/install_deps.sh | 4 +- .../linux/docker/scripts/install_deps_x86.sh | 4 +- 11 files changed, 83 insertions(+), 137 deletions(-) rename onnxruntime/{contrib_ops/cpu => core/providers/cpu/nn}/non_max_suppression.cc (91%) rename onnxruntime/{contrib_ops/cpu => core/providers/cpu/nn}/non_max_suppression.h (75%) rename onnxruntime/test/{contrib_ops => providers/cpu/nn}/non_max_suppression_test.cc (78%) diff --git a/cmake/external/onnx b/cmake/external/onnx index 83dd62659f..0e8d2bc5e5 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 83dd62659fc07d5b7fa93b5d1c1879f93509c7db +Subproject commit 0e8d2bc5e51455c70ef790b9f65aa632ed9bc8a7 diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index 607e8202c2..09c88b6027 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -13,7 +13,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); @@ -59,7 +58,6 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 0c8729dabf..ea0f6c00e5 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -875,64 +875,6 @@ with the exception that numpy default keepdims to False instead of True.)DOC") "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", AttributeProto::INT); - ONNX_CONTRIB_OPERATOR_SCHEMA(NonMaxSuppression) - .SetDomain(kMSDomain) - .SinceVersion(1) - .SetDoc(R"DOC( -Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes. -Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box. -Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to -orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system -result in the same boxes being selected by the algorithm. -The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes. -The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation. -Note: The boxes doesn't has class dimension which means it alwasy has scores calculated for different classes on same box.)DOC") - .Input( - 0, - "boxes", - "An input tensor with shape [num_batches, spatial_dimension, 4]. The single box data format is indicated by center_point_box.", - "tensor(float)") - .Input( - 1, - "scores", - "An input tensor with shape [num_batches, num_classes, spatial_dimension]", - "tensor(float)") - .Input( - 2, - "max_output_boxes_per_class", - "Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar.", - "tensor(int32)", - OpSchema::Optional) - .Input( - 3, - "iou_threshold", - "Float representing the threshold for deciding whether boxes overlap too much with respect to IOU. It is scalar. Value range [0, 1].", - "tensor(float)", - OpSchema::Optional) - .Input( - 4, - "score_threshold", - "Float representing the threshold for deciding when to remove boxes based on score. It is a scalar", - "tensor(float)", - OpSchema::Optional) - .Output( - 0, - "selected_indices", - "selected indices from the boxes tensor. [num_selected_indices, 3], the selected indices format is [batch_index, class_index, box_index].", - "tensor(int32)") - .Attr( - "center_point_box", - "Integer indicate the format of the box data. The default is 0." - "0 - the box data is supplied as [y1, x1, y2, x2] where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of box corners" - "and the coordinates can be provided as normalized (i.e., lying in the interval [0, 1]) or absolute. Mostly used for TF models." - "1 - the box data is supplied as [x_center, y_center, width, height]. Mostly used for Pytoch models.", - AttributeProto::INT, - static_cast(0)) - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - auto selected_indices_type = ctx.getOutputType(0)->mutable_tensor_type(); - selected_indices_type->set_elem_type(::ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32); - }); - ONNX_CONTRIB_OPERATOR_SCHEMA(MurmurHash3) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 73b22320b1..5355e163f9 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -264,6 +264,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, string, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, NonMaxSuppression); void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -517,6 +518,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/non_max_suppression.cc b/onnxruntime/core/providers/cpu/nn/non_max_suppression.cc similarity index 91% rename from onnxruntime/contrib_ops/cpu/non_max_suppression.cc rename to onnxruntime/core/providers/cpu/nn/non_max_suppression.cc index 66cb48e833..3c538cfae9 100644 --- a/onnxruntime/contrib_ops/cpu/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/nn/non_max_suppression.cc @@ -11,16 +11,15 @@ limitations under the License. ==============================================================================*/ /* Modifications Copyright (c) Microsoft. */ -#include "contrib_ops/cpu/non_max_suppression.h" +#include "non_max_suppression.h" #include namespace onnxruntime { -namespace contrib { ONNX_OPERATOR_KERNEL_EX( NonMaxSuppression, - kMSDomain, - 1, + kOnnxDomain, + 10, kCpuExecutionProvider, KernelDefBuilder(), NonMaxSuppression); @@ -35,7 +34,7 @@ void NonMaxSuppression::MaxMin(const float& lhs, const float& rhs, float& min, f } } -bool NonMaxSuppression::SuppressByIOU(const float* boxes_data, int32_t box_index1, int32_t box_index2, float iou_threshold) const { +bool NonMaxSuppression::SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const { float x1_min, y1_min, x1_max, y1_max, x2_min, y2_min, x2_max, y2_max; // center_point_box_ only support 0 or 1 if (0 == center_point_box_) { @@ -88,7 +87,7 @@ bool NonMaxSuppression::SuppressByIOU(const float* boxes_data, int32_t box_index } Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape, - int32_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const { + int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const { ORT_RETURN_IF_NOT(boxes_shape.NumDimensions() == 3, "boxes must be a 3D tensor."); ORT_RETURN_IF_NOT(scores_shape.NumDimensions() == 3, "scores must be a 3D tensor."); @@ -104,7 +103,7 @@ Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShap const Tensor* max_output_boxes_per_class_tensor = ctx->Input(2); if (max_output_boxes_per_class_tensor != nullptr) { - max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data()); + max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data()); max_output_boxes_per_class = max_output_boxes_per_class > 0 ? max_output_boxes_per_class : 0; } @@ -132,7 +131,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { auto& boxes_shape = boxes->Shape(); auto& scores_shape = scores->Shape(); - int32_t max_output_boxes_per_class = 0; + int64_t max_output_boxes_per_class = 0; float iou_threshold = 0; // Not so sure for the value range of score_threshold, so set a bool to indicate whether it has this input bool has_score_threshold = false; @@ -152,7 +151,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { struct ScoreIndexPair { float score; - int32_t index; + int64_t index; }; auto LessCompare = [](const ScoreIndexPair& lhs, const ScoreIndexPair& rhs) { @@ -168,12 +167,12 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { std::priority_queue, decltype(LessCompare)> sorted_scores_with_index(LessCompare); for (int64_t box_index = 0; box_index < num_boxes_; ++box_index) { if (!has_score_threshold || (has_score_threshold && scores_data[box_score_offset + box_index] > score_threshold)) { - sorted_scores_with_index.emplace(ScoreIndexPair({scores_data[box_score_offset + box_index], static_cast(box_index)})); + sorted_scores_with_index.emplace(ScoreIndexPair({scores_data[box_score_offset + box_index], box_index})); } } ScoreIndexPair next_top_score; - std::vector selected_indicies_inside_class; + std::vector selected_indicies_inside_class; // Get the next box with top score, filter by iou_threshold_ while (!sorted_scores_with_index.empty()) { next_top_score = sorted_scores_with_index.top(); @@ -189,11 +188,11 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { } if (selected) { - if (max_output_boxes_per_class > 0 && selected_indicies_inside_class.size() >= max_output_boxes_per_class) { + if (max_output_boxes_per_class > 0 && static_cast(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) { break; } selected_indicies_inside_class.push_back(next_top_score.index); - tmp_selected_indices.push_back(selected_index(static_cast(batch_index), static_cast(class_index), next_top_score.index)); + tmp_selected_indices.push_back(selected_index(batch_index, class_index, next_top_score.index)); } } //while } //for class_index @@ -202,10 +201,9 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { int32_t num_selected = static_cast(tmp_selected_indices.size()); Tensor* selected_indices = ctx->Output(0, {num_selected, 3}); ORT_ENFORCE(selected_indices); - memcpy(selected_indices->MutableData(), tmp_selected_indices.data(), num_selected * sizeof(selected_index)); + memcpy(selected_indices->MutableData(), tmp_selected_indices.data(), num_selected * sizeof(selected_index)); return Status::OK(); } -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/non_max_suppression.h b/onnxruntime/core/providers/cpu/nn/non_max_suppression.h similarity index 75% rename from onnxruntime/contrib_ops/cpu/non_max_suppression.h rename to onnxruntime/core/providers/cpu/nn/non_max_suppression.h index 19a87cd643..d384628013 100644 --- a/onnxruntime/contrib_ops/cpu/non_max_suppression.h +++ b/onnxruntime/core/providers/cpu/nn/non_max_suppression.h @@ -7,7 +7,6 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { -namespace contrib { class NonMaxSuppression final : public OpKernel { public: @@ -22,10 +21,10 @@ class NonMaxSuppression final : public OpKernel { Status Compute(OpKernelContext* context) const override; private: - bool SuppressByIOU(const float* boxes_data, int32_t box_index1, int32_t box_index2, float iou_threshold) const; + bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const; void MaxMin(const float& lhs, const float& rhs, float& min, float& max) const; Status ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape, - int32_t& max_output_boxes_per_batch, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const; + int64_t& max_output_boxes_per_batch, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const; private: int64_t center_point_box_; @@ -35,12 +34,11 @@ class NonMaxSuppression final : public OpKernel { int64_t num_boxes_; struct selected_index { - selected_index(int32_t batch_index, int32_t class_index, int32_t box_index) + selected_index(int64_t batch_index, int64_t class_index, int64_t box_index) : batch_index_(batch_index), class_index_(class_index), box_index_(box_index) {} - int32_t batch_index_ = 0; - int32_t class_index_ = 0; - int32_t box_index_ = 0; + int64_t batch_index_ = 0; + int64_t class_index_ = 0; + int64_t box_index_ = 0; }; }; -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 1cd7607cb5..3eeb63ed5f 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -350,7 +350,11 @@ int real_main(int argc, char* argv[], OrtEnv** p_env) { {"shrink", "test case is wrong"}, {"maxpool_2d_precomputed_strides", "ShapeInferenceError"}, {"averagepool_2d_precomputed_strides", "ShapeInferenceError"}, - {"maxpool_with_argmax_2d_precomputed_strides", "ShapeInferenceError"} + {"maxpool_with_argmax_2d_precomputed_strides", "ShapeInferenceError"}, + {"test_mod_bcast", "not implemented"}, + {"test_mod_float_mixed_sign_example", "not implemented"}, + {"test_mod_fmod_mixed_sign_example", "not implemented"}, + {"test_mod_int64_mixed_sign_example", "not implemented"} }; #ifdef USE_CUDA diff --git a/onnxruntime/test/contrib_ops/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/nn/non_max_suppression_test.cc similarity index 78% rename from onnxruntime/test/contrib_ops/non_max_suppression_test.cc rename to onnxruntime/test/providers/cpu/nn/non_max_suppression_test.cc index 181a52358b..e309bedaa6 100644 --- a/onnxruntime/test/contrib_ops/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/nn/non_max_suppression_test.cc @@ -8,7 +8,7 @@ namespace onnxruntime { namespace test { TEST(NonMaxSuppressionOpTest, WithIOUThreshold) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -17,10 +17,10 @@ TEST(NonMaxSuppressionOpTest, WithIOUThreshold) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {3, 3}, + test.AddOutput("selected_indices", {3, 3}, {0L, 0L, 3L, 0L, 0L, 0L, 0L, 0L, 5L}); @@ -28,7 +28,7 @@ TEST(NonMaxSuppressionOpTest, WithIOUThreshold) { } TEST(NonMaxSuppressionOpTest, CenterPointBoxFormat) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.6f, 1.0f, 1.0f, @@ -37,10 +37,10 @@ TEST(NonMaxSuppressionOpTest, CenterPointBoxFormat) { 0.5f, 10.6f, 1.0f, 1.0f, 0.5f, 100.5f, 1.0f, 1.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {3, 3}, + test.AddOutput("selected_indices", {3, 3}, {0L, 0L, 3L, 0L, 0L, 0L, 0L, 0L, 5L}); @@ -49,7 +49,7 @@ TEST(NonMaxSuppressionOpTest, CenterPointBoxFormat) { } TEST(NonMaxSuppressionOpTest, TwoClasses) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -60,10 +60,10 @@ TEST(NonMaxSuppressionOpTest, TwoClasses) { test.AddInput("scores", {1, 2, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f, 0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {6L}); + test.AddInput("max_output_boxes_per_class", {}, {6L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {6, 3}, + test.AddOutput("selected_indices", {6, 3}, {0L, 0L, 3L, 0L, 0L, 0L, 0L, 0L, 5L, @@ -74,7 +74,7 @@ TEST(NonMaxSuppressionOpTest, TwoClasses) { } TEST(NonMaxSuppressionOpTest, TwoBathes) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {2, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -92,10 +92,10 @@ TEST(NonMaxSuppressionOpTest, TwoBathes) { test.AddInput("scores", {2, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f, 0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {2L}); + test.AddInput("max_output_boxes_per_class", {}, {2L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {4, 3}, + test.AddOutput("selected_indices", {4, 3}, {0L, 0L, 3L, 0L, 0L, 0L, 1L, 0L, 3L, @@ -104,7 +104,7 @@ TEST(NonMaxSuppressionOpTest, TwoBathes) { } TEST(NonMaxSuppressionOpTest, WithScoreThreshold) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -113,17 +113,17 @@ TEST(NonMaxSuppressionOpTest, WithScoreThreshold) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.4f}); - test.AddOutput("selected_indices", {2, 3}, + test.AddOutput("selected_indices", {2, 3}, {0L, 0L, 3L, 0L, 0L, 0L}); test.Run(); } TEST(NonMaxSuppressionOpTest, WithoutScoreThreshold) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -132,9 +132,9 @@ TEST(NonMaxSuppressionOpTest, WithoutScoreThreshold) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {0.5f}); - test.AddOutput("selected_indices", {3, 3}, + test.AddOutput("selected_indices", {3, 3}, {0L, 0L, 3L, 0L, 0L, 0L, 0L, 0L, 5L}); @@ -142,7 +142,7 @@ TEST(NonMaxSuppressionOpTest, WithoutScoreThreshold) { } TEST(NonMaxSuppressionOpTest, WithScoreThresholdZeroScores) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -151,17 +151,17 @@ TEST(NonMaxSuppressionOpTest, WithScoreThresholdZeroScores) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 6}, {0.1f, 0.0f, 0.0f, 0.3f, 0.2f, -5.0f}); - test.AddInput("max_output_boxes_per_class", {}, {6L}); + test.AddInput("max_output_boxes_per_class", {}, {6L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {-3.0f}); - test.AddOutput("selected_indices", {2, 3}, + test.AddOutput("selected_indices", {2, 3}, {0L, 0L, 3L, 0L, 0L, 0L}); test.Run(); } TEST(NonMaxSuppressionOpTest, FlippedCoordinates) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -170,10 +170,10 @@ TEST(NonMaxSuppressionOpTest, FlippedCoordinates) { 1.0f, 10.1f, 0.0f, 11.1f, 1.0f, 101.0f, 0.0f, 100.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {3, 3}, + test.AddOutput("selected_indices", {3, 3}, {0L, 0L, 3L, 0L, 0L, 0L, 0L, 0L, 5L}); @@ -181,7 +181,7 @@ TEST(NonMaxSuppressionOpTest, FlippedCoordinates) { } TEST(NonMaxSuppressionOpTest, SelectTwo) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -190,17 +190,17 @@ TEST(NonMaxSuppressionOpTest, SelectTwo) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {2L}); + test.AddInput("max_output_boxes_per_class", {}, {2L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {2, 3}, + test.AddOutput("selected_indices", {2, 3}, {0L, 0L, 3L, 0L, 0L, 0L}); test.Run(); } TEST(NonMaxSuppressionOpTest, SelectThirty) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -209,10 +209,10 @@ TEST(NonMaxSuppressionOpTest, SelectThirty) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {30L}); + test.AddInput("max_output_boxes_per_class", {}, {30L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {3, 3}, + test.AddOutput("selected_indices", {3, 3}, {0L, 0L, 3L, 0L, 0L, 0L, 0L, 0L, 5L}); @@ -220,19 +220,19 @@ TEST(NonMaxSuppressionOpTest, SelectThirty) { } TEST(NonMaxSuppressionOpTest, SelectSingleBox) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 1, 4}, {0.0f, 0.0f, 1.0f, 1.0f}); test.AddInput("scores", {1, 1, 1}, {0.9f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {1, 3}, {0L, 0L, 0L}); + test.AddOutput("selected_indices", {1, 3}, {0L, 0L, 0L}); test.Run(); } TEST(NonMaxSuppressionOpTest, SelectFromIdenticalBoxes) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 10, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f, @@ -246,15 +246,15 @@ TEST(NonMaxSuppressionOpTest, SelectFromIdenticalBoxes) { 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f}); test.AddInput("scores", {1, 1, 10}, {0.9f, 0.9f, 0.9f, 0.9f, 0.9f, 0.9f, 0.9f, 0.9f, 0.9f, 0.9f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {1, 3}, {0L, 0L, 0L}); + test.AddOutput("selected_indices", {1, 3}, {0L, 0L, 0L}); test.Run(); } TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -263,37 +263,37 @@ TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 5}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f}); - test.AddInput("max_output_boxes_per_class", {}, {30L}); + test.AddInput("max_output_boxes_per_class", {}, {30L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {0, 3}, {}); + test.AddOutput("selected_indices", {0, 3}, {}); test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimention."); } TEST(NonMaxSuppressionOpTest, InvalidIOUThreshold) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 1, 4}, {0.0f, 0.0f, 1.0f, 1.0f}); test.AddInput("scores", {1, 1, 1}, {0.9f}); - test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {1.2f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {0, 3}, {}); + test.AddOutput("selected_indices", {0, 3}, {}); test.Run(OpTester::ExpectResult::kExpectFailure, "iou_threshold must be in range [0, 1]"); } TEST(NonMaxSuppressionOpTest, EmptyInput) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 0, 4}, {}); test.AddInput("scores", {1, 1, 0}, {}); - test.AddInput("max_output_boxes_per_class", {}, {30L}); + test.AddInput("max_output_boxes_per_class", {}, {30L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {0, 3}, {}); + test.AddOutput("selected_indices", {0, 3}, {}); test.Run(); } TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) { - OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.1f, 1.0f, 1.1f, @@ -302,10 +302,10 @@ TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) { 0.0f, 10.1f, 1.0f, 11.1f, 0.0f, 100.0f, 1.0f, 101.0f}); test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); - test.AddInput("max_output_boxes_per_class", {}, {0L}); + test.AddInput("max_output_boxes_per_class", {}, {0L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.4f}); - test.AddOutput("selected_indices", {0, 3}, {}); + test.AddOutput("selected_indices", {0, 3}, {}); test.Run(); } diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 877f5ad9c8..b954fde623 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -95,6 +95,10 @@ backend_test.exclude(r'(' '|^test_isinf_cpu.*' '|^test_isinf_negative_cpu.*' '|^test_isinf_positive_cpu.*' +'|^test_mod_bcast.*' +'|^test_mod_float_mixed_sign_example.*' +'|^test_mod_fmod_mixed_sign_example.*' +'|^test_mod_int64_mixed_sign_example.*' ')') # import all test cases at global scope to make diff --git a/tools/ci_build/github/linux/docker/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_deps.sh index 7e3f98fc58..c00b3df1e6 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_deps.sh @@ -38,8 +38,8 @@ else #5af210ca8a1c73aa6bae8754c9346ec54d0a756e is v1.2.3 #bae6333e149a59a3faa9c4d9c44974373dcf5256 is v1.3.0 #9e55ace55aad1ada27516038dfbdc66a8a0763db is v1.4.1 - #83dd62659fc07d5b7fa93b5d1c1879f93509c7db is v1.4.1 latest - for onnx_version in "5af210ca8a1c73aa6bae8754c9346ec54d0a756e" "bae6333e149a59a3faa9c4d9c44974373dcf5256" "9e55ace55aad1ada27516038dfbdc66a8a0763db" "83dd62659fc07d5b7fa93b5d1c1879f93509c7db"; do + #0e8d2bc5e51455c70ef790b9f65aa632ed9bc8a7 is v1.4.1 latest + for onnx_version in "5af210ca8a1c73aa6bae8754c9346ec54d0a756e" "bae6333e149a59a3faa9c4d9c44974373dcf5256" "9e55ace55aad1ada27516038dfbdc66a8a0763db" "0e8d2bc5e51455c70ef790b9f65aa632ed9bc8a7"; do if [ -z ${lastest_onnx_version+x} ]; then echo "first pass"; else diff --git a/tools/ci_build/github/linux/docker/scripts/install_deps_x86.sh b/tools/ci_build/github/linux/docker/scripts/install_deps_x86.sh index da8c8192d7..54a3c764dd 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_deps_x86.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_deps_x86.sh @@ -32,8 +32,8 @@ else #5af210ca8a1c73aa6bae8754c9346ec54d0a756e is v1.2.3 #bae6333e149a59a3faa9c4d9c44974373dcf5256 is v1.3.0 #9e55ace55aad1ada27516038dfbdc66a8a0763db is v1.4.1 - #83dd62659fc07d5b7fa93b5d1c1879f93509c7db is v1.4.1 latest - for onnx_version in "5af210ca8a1c73aa6bae8754c9346ec54d0a756e" "bae6333e149a59a3faa9c4d9c44974373dcf5256" "9e55ace55aad1ada27516038dfbdc66a8a0763db" "83dd62659fc07d5b7fa93b5d1c1879f93509c7db"; do + #0e8d2bc5e51455c70ef790b9f65aa632ed9bc8a7 is v1.4.1 latest + for onnx_version in "5af210ca8a1c73aa6bae8754c9346ec54d0a756e" "bae6333e149a59a3faa9c4d9c44974373dcf5256" "9e55ace55aad1ada27516038dfbdc66a8a0763db" "0e8d2bc5e51455c70ef790b9f65aa632ed9bc8a7"; do if [ -z ${lastest_onnx_version+x} ]; then echo "first pass"; else