diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc index 04b28b4b2d..6608454781 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc @@ -12,6 +12,7 @@ limitations under the License. /* Modifications Copyright (c) Microsoft. */ #include "non_max_suppression.h" +#include "non_max_suppression_helper.h" #include namespace onnxruntime { @@ -24,128 +25,94 @@ ONNX_OPERATOR_KERNEL_EX( KernelDefBuilder(), NonMaxSuppression); -void NonMaxSuppression::MaxMin(const float& lhs, const float& rhs, float& min, float& max) const { - if (lhs >= rhs) { - min = rhs; - max = lhs; - } else { - min = lhs; - max = rhs; +using namespace nms_helpers; + +// CPU version +namespace nms_helpers { +Status GetThresholdsFromInputs(const PrepareContext& pc, + int64_t& max_output_boxes_per_class, + float& iou_threshold, + float& score_threshold) { + if (pc.max_output_boxes_per_class_ != nullptr) { + max_output_boxes_per_class = std::max(*pc.max_output_boxes_per_class_, 0); } + + if (pc.iou_threshold_ != nullptr) { + iou_threshold = *pc.iou_threshold_; + ORT_RETURN_IF_NOT((iou_threshold >= 0 && iou_threshold <= 1.f), "iou_threshold must be in range [0, 1]."); + } + + if (pc.score_threshold_ != nullptr) { + score_threshold = *pc.score_threshold_; + } + + return Status::OK(); } +} // namespace nms_helpers -bool NonMaxSuppression::SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const { - float x1_min; - float y1_min; - float x1_max; - float y1_max; - float x2_min; - float y2_min; - float x2_max; - float y2_max; - // center_point_box_ only support 0 or 1 - if (0 == center_point_box_) { - // boxes data format [y1, x1, y2, x2], - MaxMin(boxes_data[4 * box_index1 + 1], boxes_data[4 * box_index1 + 3], x1_min, x1_max); - MaxMin(boxes_data[4 * box_index1 + 0], boxes_data[4 * box_index1 + 2], y1_min, y1_max); - MaxMin(boxes_data[4 * box_index2 + 1], boxes_data[4 * box_index2 + 3], x2_min, x2_max); - MaxMin(boxes_data[4 * box_index2 + 0], boxes_data[4 * box_index2 + 2], y2_min, y2_max); - } else { - // 1 == center_point_box_ => boxes data format [x_center, y_center, width, height] - float box1_width_half = boxes_data[4 * box_index1 + 2] / 2; - float box1_height_half = boxes_data[4 * box_index1 + 3] / 2; - float box2_width_half = boxes_data[4 * box_index2 + 2] / 2; - float box2_height_half = boxes_data[4 * box_index2 + 3] / 2; +Status NonMaxSuppressionBase::PrepareCompute(OpKernelContext* ctx, PrepareContext& pc) { + const auto* boxes_tensor = ctx->Input(0); + ORT_ENFORCE(boxes_tensor); + pc.boxes_data_ = boxes_tensor->Data(); - x1_min = boxes_data[4 * box_index1 + 0] - box1_width_half; - x1_max = boxes_data[4 * box_index1 + 0] + box1_width_half; - y1_min = boxes_data[4 * box_index1 + 1] - box1_height_half; - y1_max = boxes_data[4 * box_index1 + 1] + box1_height_half; + const auto* scores_tensor = ctx->Input(1); + ORT_ENFORCE(scores_tensor); + pc.scores_data_ = scores_tensor->Data(); - x2_min = boxes_data[4 * box_index2 + 0] - box2_width_half; - x2_max = boxes_data[4 * box_index2 + 0] + box2_width_half; - y2_min = boxes_data[4 * box_index2 + 1] - box2_height_half; - y2_max = boxes_data[4 * box_index2 + 1] + box2_height_half; + const auto num_inputs = ctx->InputCount(); + + if (num_inputs > 2) { + const auto* max_output_boxes_per_class_tensor = ctx->Input(2); + if (max_output_boxes_per_class_tensor != nullptr) { + pc.max_output_boxes_per_class_ = max_output_boxes_per_class_tensor->Data(); + } } - const float intersection_x_min = std::max(x1_min, x2_min); - const float intersection_y_min = std::max(y1_min, y2_min); - const float intersection_x_max = std::min(x1_max, x2_max); - const float intersection_y_max = std::min(y1_max, y2_max); - - const float intersection_area = std::max(intersection_x_max - intersection_x_min, static_cast(0.0)) * - std::max(intersection_y_max - intersection_y_min, static_cast(0.0)); - - if (intersection_area <= static_cast(0.0)) { - return false; + if (num_inputs > 3) { + const auto* iou_threshold_tensor = ctx->Input(3); + if (iou_threshold_tensor != nullptr) { + pc.iou_threshold_ = iou_threshold_tensor->Data(); + } } - const float area1 = (x1_max - x1_min) * (y1_max - y1_min); - const float area2 = (x2_max - x2_min) * (y2_max - y2_min); - const float union_area = area1 + area2 - intersection_area; - - if (area1 <= static_cast(0.0) || area2 <= static_cast(0.0) || union_area <= static_cast(0.0)) { - return false; + if (num_inputs > 4) { + const auto* score_threshold_tensor = ctx->Input(4); + if (score_threshold_tensor != nullptr) { + pc.score_threshold_ = score_threshold_tensor->Data(); + } } - const float intersection_over_union = intersection_area / union_area; + const auto& boxes_shape = boxes_tensor->Shape(); + pc.boxes_size_ = boxes_shape.Size(); + const auto& scores_shape = scores_tensor->Shape(); + pc.scores_size_ = scores_shape.Size(); - return intersection_over_union > iou_threshold; -} - -Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape, - int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const { ORT_RETURN_IF_NOT(boxes_shape.NumDimensions() == 3, "boxes must be a 3D tensor."); ORT_RETURN_IF_NOT(scores_shape.NumDimensions() == 3, "scores must be a 3D tensor."); auto boxes_dims = boxes_shape.GetDims(); auto scores_dims = scores_shape.GetDims(); ORT_RETURN_IF_NOT(boxes_dims[0] == scores_dims[0], "boxes and scores should have same num_batches."); - ORT_RETURN_IF_NOT(boxes_dims[1] == scores_dims[2], "boxes and scores should have same spatial_dimention."); + ORT_RETURN_IF_NOT(boxes_dims[1] == scores_dims[2], "boxes and scores should have same spatial_dimension."); ORT_RETURN_IF_NOT(boxes_dims[2] == 4, "The most inner dimension in boxes must have 4 data."); - const_cast(num_batches_) = boxes_dims[0]; - const_cast(num_classes_) = scores_dims[1]; - const_cast(num_boxes_) = boxes_dims[1]; - - const auto* max_output_boxes_per_class_tensor = ctx->Input(2); - if (max_output_boxes_per_class_tensor != nullptr) { - max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data()); - max_output_boxes_per_class = max_output_boxes_per_class > 0 ? max_output_boxes_per_class : 0; - } - - const auto* iou_threshold_tensor = ctx->Input(3); - if (iou_threshold_tensor != nullptr) { - iou_threshold = *(iou_threshold_tensor->Data()); - ORT_RETURN_IF_NOT((iou_threshold >= 0 && iou_threshold <= 1), "iou_threshold must be in range [0, 1]."); - } - - const auto* score_threshold_tensor = ctx->Input(4); - if (score_threshold_tensor != nullptr) { - has_score_threshold = true; - score_threshold = *(score_threshold_tensor->Data()); - } + pc.num_batches_ = boxes_dims[0]; + pc.num_classes_ = scores_dims[1]; + pc.num_boxes_ = boxes_dims[1]; return Status::OK(); } Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { - const auto* boxes = ctx->Input(0); - ORT_ENFORCE(boxes); - const auto* scores = ctx->Input(1); - ORT_ENFORCE(scores); - - auto& boxes_shape = boxes->Shape(); - auto& scores_shape = scores->Shape(); + PrepareContext pc; + auto ret = PrepareCompute(ctx, pc); + ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage()); int64_t max_output_boxes_per_class = 0; - float iou_threshold = 0; - // Not so sure for the value range of score_threshold, so set a bool to indicate whether it has this input - bool has_score_threshold = false; - float score_threshold = 0; + float iou_threshold = .0f; + float score_threshold = .0f; - auto ret = ParepareCompute(ctx, boxes_shape, scores_shape, max_output_boxes_per_class, - iou_threshold, score_threshold, has_score_threshold); + ret = GetThresholdsFromInputs(pc, max_output_boxes_per_class, iou_threshold, score_threshold); ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage()); if (0 == max_output_boxes_per_class) { @@ -153,63 +120,78 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const { return Status::OK(); } - const auto* boxes_data = boxes->Data(); - const auto* scores_data = scores->Data(); + const auto* const boxes_data = pc.boxes_data_; + const auto* const scores_data = pc.scores_data_; struct ScoreIndexPair { - float score; - int64_t index; + float score_{}; + int64_t index_{}; + + ScoreIndexPair() = default; + explicit ScoreIndexPair(float score, int64_t idx) : score_(score), index_(idx) {} + + bool operator<(const ScoreIndexPair& rhs) const { + return score_ < rhs.score_; + } }; - auto LessCompare = [](const ScoreIndexPair& lhs, const ScoreIndexPair& rhs) { - return lhs.score < rhs.score; - }; + const auto center_point_box = GetCenterPointBox(); - std::vector tmp_selected_indices; - for (int64_t batch_index = 0; batch_index < num_batches_; ++batch_index) { - for (int64_t class_index = 0; class_index < num_classes_; ++class_index) { - int64_t box_score_offset = (batch_index * num_classes_ + class_index) * num_boxes_; - int64_t box_offset = batch_index * num_classes_ * num_boxes_ * 4; + std::vector selected_indices; + for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) { + for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) { + int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_; + int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4; // Filter by score_threshold_ - std::priority_queue, decltype(LessCompare)> sorted_scores_with_index(LessCompare); - for (int64_t box_index = 0; box_index < num_boxes_; ++box_index) { - if (!has_score_threshold || (has_score_threshold && scores_data[box_score_offset + box_index] > score_threshold)) { - sorted_scores_with_index.emplace(ScoreIndexPair({scores_data[box_score_offset + box_index], box_index})); + std::priority_queue> sorted_scores_with_index; + const auto* class_scores = scores_data + box_score_offset; + if (pc.score_threshold_ != nullptr) { + for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) { + if (*class_scores > score_threshold) { + sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index)); + } + } + } else { + for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) { + sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index)); } } ScoreIndexPair next_top_score; std::vector selected_indicies_inside_class; - // Get the next box with top score, filter by iou_threshold_ + // Get the next box with top score, filter by iou_threshold while (!sorted_scores_with_index.empty()) { next_top_score = sorted_scores_with_index.top(); sorted_scores_with_index.pop(); bool selected = true; // Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold - for (int64_t selected_indicies_inside_clas : selected_indicies_inside_class) { - if (SuppressByIOU(boxes_data + box_offset, selected_indicies_inside_clas, next_top_score.index, - iou_threshold)) { + for (int64_t selected_index : selected_indicies_inside_class) { + if (SuppressByIOU(boxes_data + box_offset, selected_index, next_top_score.index_, + center_point_box, iou_threshold)) { selected = false; break; } } if (selected) { - if (max_output_boxes_per_class > 0 && static_cast(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) { + if (max_output_boxes_per_class > 0 && + static_cast(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) { break; } - selected_indicies_inside_class.push_back(next_top_score.index); - tmp_selected_indices.emplace_back(batch_index, class_index, next_top_score.index); + selected_indicies_inside_class.push_back(next_top_score.index_); + selected_indices.emplace_back(batch_index, class_index, next_top_score.index_); } } //while } //for class_index } //for batch_index - auto num_selected = static_cast(tmp_selected_indices.size()); - Tensor* selected_indices = ctx->Output(0, {num_selected, 3}); - ORT_ENFORCE(selected_indices); - memcpy(selected_indices->MutableData(), tmp_selected_indices.data(), num_selected * sizeof(selected_index)); + const auto last_dim = 3; + const auto num_selected = selected_indices.size(); + Tensor* output = ctx->Output(0, {static_cast(num_selected), last_dim}); + ORT_ENFORCE(output != nullptr); + static_assert(last_dim * sizeof(int64_t) == sizeof(SelectedIndex), "Possible modification of SelectedIndex"); + memcpy(output->MutableData(), selected_indices.data(), num_selected * sizeof(SelectedIndex)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.h b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.h index d384628013..37578539cd 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.h +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.h @@ -8,37 +8,30 @@ namespace onnxruntime { -class NonMaxSuppression final : public OpKernel { - public: - NonMaxSuppression(const OpKernelInfo& info) : OpKernel(info) { +struct PrepareContext; + +class NonMaxSuppressionBase { + protected: + explicit NonMaxSuppressionBase(const OpKernelInfo& info) { center_point_box_ = info.GetAttrOrDefault("center_point_box", 0); ORT_ENFORCE(0 == center_point_box_ || 1 == center_point_box_, "center_point_box only support 0 or 1"); - num_batches_ = 0; - num_classes_ = 0; - num_boxes_ = 0; } - Status Compute(OpKernelContext* context) const override; + static Status PrepareCompute(OpKernelContext* ctx, PrepareContext& pc); - private: - bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const; - void MaxMin(const float& lhs, const float& rhs, float& min, float& max) const; - Status ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape, - int64_t& max_output_boxes_per_batch, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const; + int64_t GetCenterPointBox() const { + return center_point_box_; + } private: int64_t center_point_box_; +}; - int64_t num_batches_; - int64_t num_classes_; - int64_t num_boxes_; +class NonMaxSuppression final : public OpKernel, public NonMaxSuppressionBase { + public: + explicit NonMaxSuppression(const OpKernelInfo& info) : OpKernel(info), NonMaxSuppressionBase(info) { + } - struct selected_index { - selected_index(int64_t batch_index, int64_t class_index, int64_t box_index) - : batch_index_(batch_index), class_index_(class_index), box_index_(box_index) {} - int64_t batch_index_ = 0; - int64_t class_index_ = 0; - int64_t box_index_ = 0; - }; + Status Compute(OpKernelContext* context) const override; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h new file mode 100644 index 0000000000..40b0f29618 --- /dev/null +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression_helper.h @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __NVCC__ +#include "core/providers/cuda/cu_inc/common.cuh" +#define ORT_DEVICE __device__ +#define HelperMin(a, b) _Min(a, b) +#define HelperMax(a, b) _Max(a, b) +#else +#include +#define ORT_DEVICE +#define HelperMin(a, b) std::min(a, b) +#define HelperMax(a, b) std::max(a, b) +#endif + +namespace onnxruntime { + +struct PrepareContext { + const float* boxes_data_ = nullptr; + int64_t boxes_size_ = 0ll; + const float* scores_data_ = nullptr; + int64_t scores_size_ = 0ll; + // The below are ptrs since they cab be device specific + const int64_t* max_output_boxes_per_class_ = nullptr; + const float* score_threshold_ = nullptr; + const float* iou_threshold_ = nullptr; + int64_t num_batches_ = 0; + int64_t num_classes_ = 0; + int64_t num_boxes_ = 0; +}; + +struct SelectedIndex { + ORT_DEVICE + SelectedIndex(int64_t batch_index, int64_t class_index, int64_t box_index) + : batch_index_(batch_index), class_index_(class_index), box_index_(box_index) {} + SelectedIndex() = default; + int64_t batch_index_ = 0; + int64_t class_index_ = 0; + int64_t box_index_ = 0; +}; + +#ifdef __NVCC__ +namespace cuda { +#endif +namespace nms_helpers { + +ORT_DEVICE +inline void MaxMin(float lhs, float rhs, float& min, float& max) { + if (lhs >= rhs) { + min = rhs; + max = lhs; + } else { + min = lhs; + max = rhs; + } +} + +ORT_DEVICE +inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, + int64_t center_point_box, float iou_threshold) { + float x1_min{}; + float y1_min{}; + float x1_max{}; + float y1_max{}; + float x2_min{}; + float y2_min{}; + float x2_max{}; + float y2_max{}; + + const float* box1 = boxes_data + 4 * box_index1; + const float* box2 = boxes_data + 4 * box_index2; + // center_point_box_ only support 0 or 1 + if (0 == center_point_box) { + // boxes data format [y1, x1, y2, x2], + MaxMin(box1[1], box1[3], x1_min, x1_max); + MaxMin(box1[0], box1[2], y1_min, y1_max); + MaxMin(box2[1], box2[3], x2_min, x2_max); + MaxMin(box2[0], box2[2], y2_min, y2_max); + } else { + // 1 == center_point_box_ => boxes data format [x_center, y_center, width, height] + float box1_width_half = box1[2] / 2; + float box1_height_half = box1[3] / 2; + float box2_width_half = box2[2] / 2; + float box2_height_half = box2[3] / 2; + + x1_min = box1[0] - box1_width_half; + x1_max = box1[0] + box1_width_half; + y1_min = box1[1] - box1_height_half; + y1_max = box1[1] + box1_height_half; + + x2_min = box2[0] - box2_width_half; + x2_max = box2[0] + box2_width_half; + y2_min = box2[1] - box2_height_half; + y2_max = box2[1] + box2_height_half; + } + + const float intersection_x_min = HelperMax(x1_min, x2_min); + const float intersection_y_min = HelperMax(y1_min, y2_min); + const float intersection_x_max = HelperMin(x1_max, x2_max); + const float intersection_y_max = HelperMin(y1_max, y2_max); + + const float intersection_area = HelperMax(intersection_x_max - intersection_x_min, .0f) * + HelperMax(intersection_y_max - intersection_y_min, .0f); + + if (intersection_area <= .0f) { + return false; + } + + const float area1 = (x1_max - x1_min) * (y1_max - y1_min); + const float area2 = (x2_max - x2_min) * (y2_max - y2_min); + const float union_area = area1 + area2 - intersection_area; + + if (area1 <= .0f || area2 <= .0f || union_area <= .0f) { + return false; + } + + const float intersection_over_union = intersection_area / union_area; + + return intersection_over_union > iou_threshold; +} +#ifdef __NVCC__ +} // namespace cuda +#endif +} // nms_helpers +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/nn/non_max_suppression_test.cc index e309bedaa6..9675612b7e 100644 --- a/onnxruntime/test/providers/cpu/nn/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/nn/non_max_suppression_test.cc @@ -267,7 +267,7 @@ TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) { test.AddInput("iou_threshold", {}, {0.5f}); test.AddInput("score_threshold", {}, {0.0f}); test.AddOutput("selected_indices", {0, 3}, {}); - test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimention."); + test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimension."); } TEST(NonMaxSuppressionOpTest, InvalidIOUThreshold) {