mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Improves NonMaxSuppression on CPU (#7557)
* Improves non-max suppression
* Uses a pointer instead of boxes
This commit is contained in:
parent
ade6ed51eb
commit
2f0479780e
2 changed files with 43 additions and 58 deletions
|
|
@ -132,62 +132,28 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
|
|||
const auto* const boxes_data = pc.boxes_data_;
|
||||
const auto* const scores_data = pc.scores_data_;
|
||||
|
||||
struct BoxInfo {
|
||||
struct BoxInfoPtr {
|
||||
float score_{};
|
||||
int64_t index_{};
|
||||
float box_[4]{};
|
||||
float area_{};
|
||||
|
||||
BoxInfo() = default;
|
||||
explicit BoxInfo(float score, int64_t idx, int64_t center_point_box, const float* box) : score_(score), index_(idx) {
|
||||
if (0 == center_point_box) {
|
||||
// boxes data format [y1, x1, y2, x2],
|
||||
MaxMin(box[1], box[3], box_[1], box_[3]);
|
||||
MaxMin(box[0], box[2], box_[0], box_[2]);
|
||||
} else {
|
||||
// boxes data format [x_center, y_center, width, height]
|
||||
float box_width_half = box[2] / 2;
|
||||
float box_height_half = box[3] / 2;
|
||||
box_[1] = box[0] - box_width_half;
|
||||
box_[3] = box[0] + box_width_half;
|
||||
box_[0] = box[1] - box_height_half;
|
||||
box_[2] = box[1] + box_height_half;
|
||||
}
|
||||
area_ = (box_[2] - box_[0]) * (box_[3] - box_[1]);
|
||||
}
|
||||
|
||||
inline bool operator<(const BoxInfo& rhs) const {
|
||||
BoxInfoPtr() = default;
|
||||
explicit BoxInfoPtr(float score, int64_t idx) : score_(score), index_(idx) {}
|
||||
inline bool operator<(const BoxInfoPtr& rhs) const {
|
||||
return score_ < rhs.score_ || (score_ == rhs.score_ && index_ > rhs.index_);
|
||||
}
|
||||
|
||||
inline bool SuppressByIOU(const BoxInfo& rhs, float iou_threshold) const {
|
||||
const float intersection_x_min = std::max(box_[1], rhs.box_[1]);
|
||||
const float intersection_y_min = std::max(box_[0], rhs.box_[0]);
|
||||
const float intersection_x_max = std::min(box_[3], rhs.box_[3]);
|
||||
const float intersection_y_max = std::min(box_[2], rhs.box_[2]);
|
||||
|
||||
const float intersection_area = std::max(intersection_x_max - intersection_x_min, .0f) *
|
||||
std::max(intersection_y_max - intersection_y_min, .0f);
|
||||
if (intersection_area <= .0f) {
|
||||
return false;
|
||||
}
|
||||
const float union_area = area_ + rhs.area_ - intersection_area;
|
||||
const float intersection_over_union = intersection_area / union_area;
|
||||
return intersection_over_union > iou_threshold;
|
||||
}
|
||||
};
|
||||
|
||||
const auto center_point_box = GetCenterPointBox();
|
||||
|
||||
std::vector<SelectedIndex> selected_indices;
|
||||
std::vector<BoxInfo> selected_boxes_inside_class;
|
||||
std::vector<BoxInfoPtr> selected_boxes_inside_class;
|
||||
selected_boxes_inside_class.reserve(std::min<size_t>(static_cast<size_t>(max_output_boxes_per_class), pc.num_boxes_));
|
||||
|
||||
for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
|
||||
for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
|
||||
int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_;
|
||||
const float* batch_boxes = boxes_data + (batch_index * pc.num_boxes_ * 4);
|
||||
std::vector<BoxInfo> candidate_boxes;
|
||||
std::vector<BoxInfoPtr> candidate_boxes;
|
||||
candidate_boxes.reserve(pc.num_boxes_);
|
||||
|
||||
// Filter by score_threshold_
|
||||
|
|
@ -195,25 +161,25 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
|
|||
if (pc.score_threshold_ != nullptr) {
|
||||
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
|
||||
if (*class_scores > score_threshold) {
|
||||
candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4));
|
||||
candidate_boxes.emplace_back(*class_scores, box_index);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
|
||||
candidate_boxes.emplace_back(*class_scores, box_index, center_point_box, batch_boxes + (box_index * 4));
|
||||
candidate_boxes.emplace_back(*class_scores, box_index);
|
||||
}
|
||||
}
|
||||
std::priority_queue<BoxInfo, std::vector<BoxInfo>> sorted_boxes(std::less<BoxInfo>(), std::move(candidate_boxes));
|
||||
std::priority_queue<BoxInfoPtr, std::vector<BoxInfoPtr>> sorted_boxes(std::less<BoxInfoPtr>(), std::move(candidate_boxes));
|
||||
|
||||
selected_boxes_inside_class.clear();
|
||||
// Get the next box with top score, filter by iou_threshold
|
||||
while (!sorted_boxes.empty() && static_cast<int64_t>(selected_boxes_inside_class.size()) < max_output_boxes_per_class) {
|
||||
const BoxInfo& next_top_score = sorted_boxes.top();
|
||||
const BoxInfoPtr& next_top_score = sorted_boxes.top();
|
||||
|
||||
bool selected = true;
|
||||
// Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold
|
||||
for (const auto& selected_index : selected_boxes_inside_class) {
|
||||
if (next_top_score.SuppressByIOU(selected_index, iou_threshold)) {
|
||||
if (SuppressByIOU(batch_boxes, next_top_score.index_, selected_index.index_, center_point_box, iou_threshold)) {
|
||||
selected = false;
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,6 +70,10 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b
|
|||
float y2_min{};
|
||||
float x2_max{};
|
||||
float y2_max{};
|
||||
float intersection_x_min{};
|
||||
float intersection_x_max{};
|
||||
float intersection_y_min{};
|
||||
float intersection_y_max{};
|
||||
|
||||
const float* box1 = boxes_data + 4 * box_index1;
|
||||
const float* box2 = boxes_data + 4 * box_index2;
|
||||
|
|
@ -77,9 +81,19 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b
|
|||
if (0 == center_point_box) {
|
||||
// boxes data format [y1, x1, y2, x2],
|
||||
MaxMin(box1[1], box1[3], x1_min, x1_max);
|
||||
MaxMin(box1[0], box1[2], y1_min, y1_max);
|
||||
MaxMin(box2[1], box2[3], x2_min, x2_max);
|
||||
|
||||
intersection_x_min = HelperMax(x1_min, x2_min);
|
||||
intersection_x_max = HelperMin(x1_max, x2_max);
|
||||
if (intersection_x_max <= intersection_x_min)
|
||||
return false;
|
||||
|
||||
MaxMin(box1[0], box1[2], y1_min, y1_max);
|
||||
MaxMin(box2[0], box2[2], y2_min, y2_max);
|
||||
intersection_y_min = HelperMax(y1_min, y2_min);
|
||||
intersection_y_max = HelperMin(y1_max, y2_max);
|
||||
if (intersection_y_max <= intersection_y_min)
|
||||
return false;
|
||||
} else {
|
||||
// 1 == center_point_box_ => boxes data format [x_center, y_center, width, height]
|
||||
float box1_width_half = box1[2] / 2;
|
||||
|
|
@ -89,22 +103,27 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b
|
|||
|
||||
x1_min = box1[0] - box1_width_half;
|
||||
x1_max = box1[0] + box1_width_half;
|
||||
y1_min = box1[1] - box1_height_half;
|
||||
y1_max = box1[1] + box1_height_half;
|
||||
|
||||
x2_min = box2[0] - box2_width_half;
|
||||
x2_max = box2[0] + box2_width_half;
|
||||
|
||||
intersection_x_min = HelperMax(x1_min, x2_min);
|
||||
intersection_x_max = HelperMin(x1_max, x2_max);
|
||||
if (intersection_x_max <= intersection_x_min)
|
||||
return false;
|
||||
|
||||
y1_min = box1[1] - box1_height_half;
|
||||
y1_max = box1[1] + box1_height_half;
|
||||
y2_min = box2[1] - box2_height_half;
|
||||
y2_max = box2[1] + box2_height_half;
|
||||
|
||||
intersection_y_min = HelperMax(y1_min, y2_min);
|
||||
intersection_y_max = HelperMin(y1_max, y2_max);
|
||||
if (intersection_y_max <= intersection_y_min)
|
||||
return false;
|
||||
}
|
||||
|
||||
const float intersection_x_min = HelperMax(x1_min, x2_min);
|
||||
const float intersection_y_min = HelperMax(y1_min, y2_min);
|
||||
const float intersection_x_max = HelperMin(x1_max, x2_max);
|
||||
const float intersection_y_max = HelperMin(y1_max, y2_max);
|
||||
|
||||
const float intersection_area = HelperMax(intersection_x_max - intersection_x_min, .0f) *
|
||||
HelperMax(intersection_y_max - intersection_y_min, .0f);
|
||||
const float intersection_area = (intersection_x_max - intersection_x_min) *
|
||||
(intersection_y_max - intersection_y_min);
|
||||
|
||||
if (intersection_area <= .0f) {
|
||||
return false;
|
||||
|
|
@ -123,7 +142,7 @@ inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t b
|
|||
return intersection_over_union > iou_threshold;
|
||||
}
|
||||
#ifdef __NVCC__
|
||||
} // namespace cuda
|
||||
} // namespace cuda
|
||||
#endif
|
||||
} // nms_helpers
|
||||
} // namespace nms_helpers
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue