From 2f0479780e4abfa78deb35ad703b832e84bc4033 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 5 May 2021 09:14:40 +0200 Subject: [PATCH] Improves NonMaxSuppression on CPU (#7557) * improves non max suppression * use pointer instead of boxes --- .../object_detection/non_max_suppression.cc | 56 ++++--------------- .../non_max_suppression_helper.h | 45 ++++++++++----- 2 files changed, 43 insertions(+), 58 deletions(-) diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc index fca1a20bc1..bc6b7e3543 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc @@ -132,62 +132,28 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { const auto* const boxes_data = pc.boxes_data_; const auto* const scores_data = pc.scores_data_; - struct BoxInfo { + struct BoxInfoPtr { float score_{}; int64_t index_{}; - float box_[4]{}; - float area_{}; - BoxInfo() = default; - explicit BoxInfo(float score, int64_t idx, int64_t center_point_box, const float* box) : score_(score), index_(idx) { - if (0 == center_point_box) { - // boxes data format [y1, x1, y2, x2], - MaxMin(box[1], box[3], box_[1], box_[3]); - MaxMin(box[0], box[2], box_[0], box_[2]); - } else { - // boxes data format [x_center, y_center, width, height] - float box_width_half = box[2] / 2; - float box_height_half = box[3] / 2; - box_[1] = box[0] - box_width_half; - box_[3] = box[0] + box_width_half; - box_[0] = box[1] - box_height_half; - box_[2] = box[1] + box_height_half; - } - area_ = (box_[2] - box_[0]) * (box_[3] - box_[1]); - } - - inline bool operator<(const BoxInfo& rhs) const { + BoxInfoPtr() = default; + explicit BoxInfoPtr(float score, int64_t idx) : score_(score), index_(idx) {} + inline bool operator<(const BoxInfoPtr& rhs) const { return score_ < rhs.score_ || (score_ == rhs.score_ && index_ > rhs.index_); } - - inline bool SuppressByIOU(const BoxInfo& rhs, float iou_threshold) const { - const float intersection_x_min = std::max(box_[1], rhs.box_[1]); - const float intersection_y_min = std::max(box_[0], rhs.box_[0]); - const float intersection_x_max = std::min(box_[3], rhs.box_[3]); - const float intersection_y_max = std::min(box_[2], rhs.box_[2]); - - const float intersection_area = std::max(intersection_x_max - intersection_x_min, .0f) * - std::max(intersection_y_max - intersection_y_min, .0f); - if (intersection_area <= .0f) { - return false; - } - const float union_area = area_ + rhs.area_ - intersection_area; - const float intersection_over_union = intersection_area / union_area; - return intersection_over_union > iou_threshold; - } }; const auto center_point_box = GetCenterPointBox(); std::vector selected_indices; - std::vector selected_boxes_inside_class; + std::vector selected_boxes_inside_class; selected_boxes_inside_class.reserve(std::min(static_cast(max_output_boxes_per_class), pc.num_boxes_)); for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) { for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) { int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_; const float* batch_boxes = boxes_data + (batch_index * pc.num_boxes_ * 4); - std::vector candidate_boxes; + std::vector candidate_boxes; candidate_boxes.reserve(pc.num_boxes_); // Filter by score_threshold_ @@ -195,25 +161,25 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { if (pc.score_threshold_ != nullptr) { for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) { if (*class_scores > score_threshold) { - candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4)); + candidate_boxes.emplace_back(*class_scores, box_index); } } } else { for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) { - candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4)); + candidate_boxes.emplace_back(*class_scores, box_index); } } - std::priority_queue> sorted_boxes(std::less(), std::move(candidate_boxes)); + std::priority_queue> sorted_boxes(std::less(), std::move(candidate_boxes)); selected_boxes_inside_class.clear(); // Get the next box with top score, filter by iou_threshold while (!sorted_boxes.empty() && static_cast(selected_boxes_inside_class.size()) < max_output_boxes_per_class) { - const BoxInfo& next_top_score = sorted_boxes.top(); + const BoxInfoPtr& next_top_score = sorted_boxes.top(); bool selected = true; // Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold for (const auto& selected_index : selected_boxes_inside_class) { - if (next_top_score.SuppressByIOU(selected_index, iou_threshold)) { + if (SuppressByIOU(batch_boxes, next_top_score.index_, selected_index.index_, center_point_box, iou_threshold)) { selected = false; break; } diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h index b010005643..6417c125fb 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h @@ -70,6 +70,10 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b float y2_min{}; float x2_max{}; float y2_max{}; + float intersection_x_min{}; + float intersection_x_max{}; + float intersection_y_min{}; + float intersection_y_max{}; const float* box1 = boxes_data + 4 * box_index1; const float* box2 = boxes_data + 4 * box_index2; @@ -77,9 +81,19 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b if (0 == center_point_box) { // boxes data format [y1, x1, y2, x2], MaxMin(box1[1], box1[3], x1_min, x1_max); - MaxMin(box1[0], box1[2], y1_min, y1_max); MaxMin(box2[1], box2[3], x2_min, x2_max); + + intersection_x_min = HelperMax(x1_min, x2_min); + intersection_x_max = HelperMin(x1_max, x2_max); + if (intersection_x_max <= intersection_x_min) + return false; + + MaxMin(box1[0], box1[2], y1_min, y1_max); MaxMin(box2[0], box2[2], y2_min, y2_max); + intersection_y_min = HelperMax(y1_min, y2_min); + intersection_y_max = HelperMin(y1_max, y2_max); + if (intersection_y_max <= intersection_y_min) + return false; } else { // 1 == center_point_box_ => boxes data format [x_center, y_center, width, height] float box1_width_half = box1[2] / 2; @@ -89,22 +103,27 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b x1_min = box1[0] - box1_width_half; x1_max = box1[0] + box1_width_half; - y1_min = box1[1] - box1_height_half; - y1_max = box1[1] + box1_height_half; - x2_min = box2[0] - box2_width_half; x2_max = box2[0] + box2_width_half; + + intersection_x_min = HelperMax(x1_min, x2_min); + intersection_x_max = HelperMin(x1_max, x2_max); + if (intersection_x_max <= intersection_x_min) + return false; + + y1_min = box1[1] - box1_height_half; + y1_max = box1[1] + box1_height_half; y2_min = box2[1] - box2_height_half; y2_max = box2[1] + box2_height_half; + + intersection_y_min = HelperMax(y1_min, y2_min); + intersection_y_max = HelperMin(y1_max, y2_max); + if (intersection_y_max <= intersection_y_min) + return false; } - const float intersection_x_min = HelperMax(x1_min, x2_min); - const float intersection_y_min = HelperMax(y1_min, y2_min); - const float intersection_x_max = HelperMin(x1_max, x2_max); - const float intersection_y_max = HelperMin(y1_max, y2_max); - - const float intersection_area = HelperMax(intersection_x_max - intersection_x_min, .0f) * - HelperMax(intersection_y_max - intersection_y_min, .0f); + const float intersection_area = (intersection_x_max - intersection_x_min) * + (intersection_y_max - intersection_y_min); if (intersection_area <= .0f) { return false; @@ -123,7 +142,7 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b return intersection_over_union > iou_threshold; } #ifdef __NVCC__ -} // namespace cuda +} // namespace cuda #endif -} // nms_helpers +} // namespace nms_helpers } // namespace onnxruntime