From 0741baf867b4c995601d8510552ecc2afcc6d8ab Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 11 Apr 2019 10:09:33 -0700 Subject: [PATCH] Update NMS to support max_output_boxes_per_class = 0. NMS will do nothing for this case. (#816) --- .../contrib_ops/cpu/non_max_suppression.cc | 7 ++++++- .../core/graph/contrib_ops/contrib_defs.cc | 2 +- .../contrib_ops/non_max_suppression_test.cc | 21 +++++++++++++++++-- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/non_max_suppression.cc b/onnxruntime/contrib_ops/cpu/non_max_suppression.cc index 86ea63ba22..66cb48e833 100644 --- a/onnxruntime/contrib_ops/cpu/non_max_suppression.cc +++ b/onnxruntime/contrib_ops/cpu/non_max_suppression.cc @@ -105,7 +105,7 @@ Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShap const Tensor* max_output_boxes_per_class_tensor = ctx->Input(2); if (max_output_boxes_per_class_tensor != nullptr) { max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data()); - ORT_RETURN_IF_NOT(max_output_boxes_per_class > 0, "max_output_boxes_per_class should be greater than 0."); + max_output_boxes_per_class = max_output_boxes_per_class > 0 ? max_output_boxes_per_class : 0; } const Tensor* iou_threshold_tensor = ctx->Input(3); @@ -142,6 +142,11 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { iou_threshold, score_threshold, has_score_threshold); ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage()); + if (0 == max_output_boxes_per_class) { + ctx->Output(0, {0, 3}); + return Status::OK(); + } + const float* boxes_data = boxes->Data(); const float* scores_data = scores->Data(); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 724f037f16..348b7714cf 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -893,7 +893,7 @@ Note: The boxes doesn't has class dimension which means it alwasy has scores cal .Input( 2, "max_output_boxes_per_class", - "Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Value should be greater than 0", + "Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar.", "tensor(int32)", OpSchema::Optional) .Input( diff --git a/onnxruntime/test/contrib_ops/non_max_suppression_test.cc b/onnxruntime/test/contrib_ops/non_max_suppression_test.cc index 5e034ccd3e..181a52358b 100644 --- a/onnxruntime/test/contrib_ops/non_max_suppression_test.cc +++ b/onnxruntime/test/contrib_ops/non_max_suppression_test.cc @@ -266,7 +266,7 @@ TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) { test.AddInput("max_output_boxes_per_class", {}, {30L}); test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {0}, {}); + test.AddOutput("selected_indices", {0, 3}, {}); test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimention."); } @@ -277,7 +277,7 @@ TEST(NonMaxSuppressionOpTest, InvalidIOUThreshold) { test.AddInput("max_output_boxes_per_class", {}, {3L}); test.AddInput("iou_threshold", {}, {1.2f}); test.AddInput("score_threshold", {}, {0.0f}); - test.AddOutput("selected_indices", {0}, {}); + test.AddOutput("selected_indices", {0, 3}, {}); test.Run(OpTester::ExpectResult::kExpectFailure, "iou_threshold must be in range [0, 1]"); } @@ -292,5 +292,22 @@ TEST(NonMaxSuppressionOpTest, EmptyInput) { test.Run(); } +TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) { + OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain); + test.AddInput("boxes", {1, 6, 4}, + {0.0f, 0.0f, 1.0f, 1.0f, + 0.0f, 0.1f, 1.0f, 1.1f, + 0.0f, -0.1f, 1.0f, 0.9f, + 0.0f, 10.0f, 1.0f, 11.0f, + 0.0f, 10.1f, 1.0f, 11.1f, + 0.0f, 100.0f, 1.0f, 101.0f}); + test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); + test.AddInput("max_output_boxes_per_class", {}, {0L}); + test.AddInput("iou_threshold", {}, {0.5f}); + test.AddInput("score_threshold", {}, {0.4f}); + test.AddOutput("selected_indices", {0, 3}, {}); + test.Run(); +} + } // namespace test } // namespace onnxruntime