Zhalei/optimize nms (#4875)

* double the speed of non_max_suppression for cpu.

* handle edge case in test case.
This commit is contained in:
Zhang Lei 2020-08-31 23:33:54 -07:00 committed by GitHub
parent cf1b74396a
commit 464bbd27a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -14,6 +14,7 @@ limitations under the License.
#include "non_max_suppression.h"
#include "non_max_suppression_helper.h"
#include <queue>
#include <utility>
namespace onnxruntime {
@ -130,65 +131,98 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
const auto* const boxes_data = pc.boxes_data_;
const auto* const scores_data = pc.scores_data_;
struct ScoreIndexPair {
struct BoxInfo {
float score_{};
int64_t index_{};
float box_[4]{};
float area_{};
ScoreIndexPair() = default;
explicit ScoreIndexPair(float score, int64_t idx) : score_(score), index_(idx) {}
BoxInfo() = default;
explicit BoxInfo(float score, int64_t idx, int64_t center_point_box, const float* box) : score_(score), index_(idx) {
if (0 == center_point_box) {
// boxes data format [y1, x1, y2, x2]
MaxMin(box[1], box[3], box_[1], box_[3]);
MaxMin(box[0], box[2], box_[0], box_[2]);
} else {
// boxes data format [x_center, y_center, width, height]
float box_width_half = box[2] / 2;
float box_height_half = box[3] / 2;
box_[1] = box[0] - box_width_half;
box_[3] = box[0] + box_width_half;
box_[0] = box[1] - box_height_half;
box_[2] = box[1] + box_height_half;
}
area_ = (box_[2] - box_[0]) * (box_[3] - box_[1]);
}
bool operator<(const ScoreIndexPair& rhs) const {
return score_ < rhs.score_;
inline bool operator<(const BoxInfo& rhs) const {
return score_ < rhs.score_ || (score_ == rhs.score_ && index_ > rhs.index_);
}
inline bool SuppressByIOU(const BoxInfo& rhs, float iou_threshold) const {
const float intersection_x_min = std::max(box_[1], rhs.box_[1]);
const float intersection_y_min = std::max(box_[0], rhs.box_[0]);
const float intersection_x_max = std::min(box_[3], rhs.box_[3]);
const float intersection_y_max = std::min(box_[2], rhs.box_[2]);
const float intersection_area = std::max(intersection_x_max - intersection_x_min, .0f) *
std::max(intersection_y_max - intersection_y_min, .0f);
if (intersection_area <= .0f) {
return false;
}
const float union_area = area_ + rhs.area_ - intersection_area;
const float intersection_over_union = intersection_area / union_area;
return intersection_over_union > iou_threshold;
}
};
const auto center_point_box = GetCenterPointBox();
std::vector<SelectedIndex> selected_indices;
std::vector<BoxInfo> selected_boxes_inside_class;
selected_boxes_inside_class.reserve(std::min(max_output_boxes_per_class, pc.num_boxes_));
for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_;
int64_t box_offset = batch_index * pc.num_boxes_ * 4;
const float* batch_boxes = boxes_data + (batch_index * pc.num_boxes_ * 4);
std::vector<BoxInfo> candidate_boxes;
candidate_boxes.reserve(pc.num_boxes_);
// Filter by score_threshold_
std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>> sorted_scores_with_index;
const auto* class_scores = scores_data + box_score_offset;
if (pc.score_threshold_ != nullptr) {
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
if (*class_scores > score_threshold) {
sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index));
candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4));
}
}
} else {
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index));
candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4));
}
}
std::priority_queue<BoxInfo, std::vector<BoxInfo>> sorted_boxes(std::less<BoxInfo>(), std::move(candidate_boxes));
ScoreIndexPair next_top_score;
std::vector<int64_t> selected_indices_inside_class;
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
while (!sorted_scores_with_index.empty()) {
next_top_score = sorted_scores_with_index.top();
sorted_scores_with_index.pop();
while (!sorted_boxes.empty() && static_cast<int64_t>(selected_boxes_inside_class.size()) < max_output_boxes_per_class) {
const BoxInfo& next_top_score = sorted_boxes.top();
bool selected = true;
// Compare against the boxes already selected for this class; suppress the candidate if its IOU (Intersection Over Union) with any of them exceeds the threshold
for (int64_t selected_index : selected_indices_inside_class) {
if (SuppressByIOU(boxes_data + box_offset, selected_index, next_top_score.index_,
center_point_box, iou_threshold)) {
for (const auto& selected_index : selected_boxes_inside_class) {
if (next_top_score.SuppressByIOU(selected_index, iou_threshold)) {
selected = false;
break;
}
}
if (selected) {
if (max_output_boxes_per_class > 0 &&
static_cast<int64_t>(selected_indices_inside_class.size()) >= max_output_boxes_per_class) {
break;
}
selected_indices_inside_class.push_back(next_top_score.index_);
selected_boxes_inside_class.push_back(next_top_score);
selected_indices.emplace_back(batch_index, class_index, next_top_score.index_);
}
sorted_boxes.pop();
} //while
} //for class_index
} //for batch_index