diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc index 449bcd407a..5698aa5381 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc @@ -14,6 +14,7 @@ limitations under the License. #include "non_max_suppression.h" #include "non_max_suppression_helper.h" #include +#include namespace onnxruntime { @@ -130,65 +131,98 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { const auto* const boxes_data = pc.boxes_data_; const auto* const scores_data = pc.scores_data_; - struct ScoreIndexPair { + struct BoxInfo { float score_{}; int64_t index_{}; + float box_[4]{}; + float area_{}; - ScoreIndexPair() = default; - explicit ScoreIndexPair(float score, int64_t idx) : score_(score), index_(idx) {} + BoxInfo() = default; + explicit BoxInfo(float score, int64_t idx, int64_t center_point_box, const float* box) : score_(score), index_(idx) { + if (0 == center_point_box) { + // boxes data format [y1, x1, y2, x2], + MaxMin(box[1], box[3], box_[1], box_[3]); + MaxMin(box[0], box[2], box_[0], box_[2]); + } else { + // boxes data format [x_center, y_center, width, height] + float box_width_half = box[2] / 2; + float box_height_half = box[3] / 2; + box_[1] = box[0] - box_width_half; + box_[3] = box[0] + box_width_half; + box_[0] = box[1] - box_height_half; + box_[2] = box[1] + box_height_half; + } + area_ = (box_[2] - box_[0]) * (box_[3] - box_[1]); + } - bool operator<(const ScoreIndexPair& rhs) const { - return score_ < rhs.score_; + inline bool operator<(const BoxInfo& rhs) const { + return score_ < rhs.score_ || (score_ == rhs.score_ && index_ > rhs.index_); + } + + inline bool SuppressByIOU(const BoxInfo& rhs, float iou_threshold) const { + const float intersection_x_min = std::max(box_[1], rhs.box_[1]); + const float intersection_y_min = std::max(box_[0], rhs.box_[0]); + const float intersection_x_max = std::min(box_[3], rhs.box_[3]); + const float intersection_y_max = std::min(box_[2], rhs.box_[2]); + + const float intersection_area = std::max(intersection_x_max - intersection_x_min, .0f) * + std::max(intersection_y_max - intersection_y_min, .0f); + if (intersection_area <= .0f) { + return false; + } + const float union_area = area_ + rhs.area_ - intersection_area; + const float intersection_over_union = intersection_area / union_area; + return intersection_over_union > iou_threshold; } }; const auto center_point_box = GetCenterPointBox(); std::vector selected_indices; + std::vector selected_boxes_inside_class; + selected_boxes_inside_class.reserve(std::min(max_output_boxes_per_class, pc.num_boxes_)); + for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) { for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) { int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_; - int64_t box_offset = batch_index * pc.num_boxes_ * 4; + const float* batch_boxes = boxes_data + (batch_index * pc.num_boxes_ * 4); + std::vector candidate_boxes; + candidate_boxes.reserve(pc.num_boxes_); + // Filter by score_threshold_ - std::priority_queue> sorted_scores_with_index; const auto* class_scores = scores_data + box_score_offset; if (pc.score_threshold_ != nullptr) { for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) { if (*class_scores > score_threshold) { - sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index)); + candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4)); } } } else { for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) { - sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index)); + candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4)); } } + std::priority_queue> sorted_boxes(std::less(), std::move(candidate_boxes)); - ScoreIndexPair next_top_score; - std::vector selected_indices_inside_class; + selected_boxes_inside_class.clear(); // Get the next box with top score, filter by iou_threshold - while (!sorted_scores_with_index.empty()) { - next_top_score = sorted_scores_with_index.top(); - sorted_scores_with_index.pop(); + while (!sorted_boxes.empty() && static_cast(selected_boxes_inside_class.size()) < max_output_boxes_per_class) { + const BoxInfo& next_top_score = sorted_boxes.top(); bool selected = true; // Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold - for (int64_t selected_index : selected_indices_inside_class) { - if (SuppressByIOU(boxes_data + box_offset, selected_index, next_top_score.index_, - center_point_box, iou_threshold)) { + for (const auto& selected_index : selected_boxes_inside_class) { + if (next_top_score.SuppressByIOU(selected_index, iou_threshold)) { selected = false; break; } } if (selected) { - if (max_output_boxes_per_class > 0 && - static_cast(selected_indices_inside_class.size()) >= max_output_boxes_per_class) { - break; - } - selected_indices_inside_class.push_back(next_top_score.index_); + selected_boxes_inside_class.push_back(next_top_score); selected_indices.emplace_back(batch_index, class_index, next_top_score.index_); } + sorted_boxes.pop(); } //while } //for class_index } //for batch_index