From 7545b795dfcea271e8a964354d4d012c8c7bdd80 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Thu, 15 Aug 2019 11:41:10 -0700 Subject: [PATCH] Fix incorrect box offset computation in NMS op (#1624) * More changes * Fix NMS * nits --- .../object_detection/non_max_suppression.cc | 2 +- .../non_max_suppression_test.cc | 37 ++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc index 6608454781..bdb57248c4 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc @@ -141,7 +141,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) { for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) { int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_; - int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4; + int64_t box_offset = batch_index * pc.num_boxes_ * 4; // Filter by score_threshold_ std::priority_queue> sorted_scores_with_index; const auto* class_scores = scores_data + box_score_offset; diff --git a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc index 9675612b7e..45f537bc89 100644 --- a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc @@ -73,7 +73,7 @@ TEST(NonMaxSuppressionOpTest, TwoClasses) { test.Run(); } -TEST(NonMaxSuppressionOpTest, TwoBathes) { +TEST(NonMaxSuppressionOpTest, TwoBatches_OneClass) { OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {2, 6, 4}, {0.0f, 0.0f, 1.0f, 1.0f, @@ -103,6 +103,41 @@ TEST(NonMaxSuppressionOpTest, TwoBathes) { test.Run(); } +TEST(NonMaxSuppressionOpTest, TwoBatches_TwoClasses) { + OpTester test("NonMaxSuppression", 10, kOnnxDomain); + test.AddInput("boxes", {2, 5, 4}, + {0.0f, 0.0f, 0.3f, 0.3f, + 0.0f, 0.0f, 0.4f, 0.4f, + 0.0f, 0.0f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.9f, 0.9f, + 0.5f, 0.5f, 1.0f, 1.0f, + + 0.0f, 0.0f, 0.3f, 0.3f, + 0.0f, 0.0f, 0.4f, 0.4f, + 0.5f, 0.5f, 0.95f, 0.95f, + 0.5f, 0.5f, 0.96f, 0.96f, + 0.5f, 0.5f, 1.0f, 1.0f}); + test.AddInput("scores", {2, 2, 5}, + {0.1f, 0.2f, 0.6f, 0.3f, 0.9f, + 0.1f, 0.2f, 0.6f, 0.3f, 0.9f, + + 0.1f, 0.2f, 0.6f, 0.3f, 0.9f, + 0.1f, 0.2f, 0.6f, 0.3f, 0.9f}); + test.AddInput("max_output_boxes_per_class", {}, {2L}); + test.AddInput("iou_threshold", {}, {0.8f}); + test.AddOutput("selected_indices", {8, 3}, + {0L, 0L, 4L, + 0L, 0L, 2L, + 0L, 1L, 4L, + 0L, 1L, 2L, + + 1L, 0L, 4L, + 1L, 0L, 1L, + 1L, 1L, 4L, + 1L, 1L, 1L}); + test.Run(); +} + TEST(NonMaxSuppressionOpTest, WithScoreThreshold) { OpTester test("NonMaxSuppression", 10, kOnnxDomain); test.AddInput("boxes", {1, 6, 4},