mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-30 23:18:20 +00:00
Fix NMS const_cast that modified kernel state creating (#1303)
* Fix NMS const_cast that modified kernel state creating thread safety issue. Re-factor for future CUDA implementation.
This commit is contained in:
parent
04d581995d
commit
2f698bd54b
4 changed files with 248 additions and 144 deletions
|
|
@ -12,6 +12,7 @@ limitations under the License.
|
|||
/* Modifications Copyright (c) Microsoft. */
|
||||
|
||||
#include "non_max_suppression.h"
|
||||
#include "non_max_suppression_helper.h"
|
||||
#include <queue>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -24,128 +25,94 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
KernelDefBuilder(),
|
||||
NonMaxSuppression);
|
||||
|
||||
void NonMaxSuppression::MaxMin(const float& lhs, const float& rhs, float& min, float& max) const {
|
||||
if (lhs >= rhs) {
|
||||
min = rhs;
|
||||
max = lhs;
|
||||
} else {
|
||||
min = lhs;
|
||||
max = rhs;
|
||||
using namespace nms_helpers;
|
||||
|
||||
// CPU version
|
||||
namespace nms_helpers {
|
||||
Status GetThresholdsFromInputs(const PrepareContext& pc,
|
||||
int64_t& max_output_boxes_per_class,
|
||||
float& iou_threshold,
|
||||
float& score_threshold) {
|
||||
if (pc.max_output_boxes_per_class_ != nullptr) {
|
||||
max_output_boxes_per_class = std::max<int64_t>(*pc.max_output_boxes_per_class_, 0);
|
||||
}
|
||||
|
||||
if (pc.iou_threshold_ != nullptr) {
|
||||
iou_threshold = *pc.iou_threshold_;
|
||||
ORT_RETURN_IF_NOT((iou_threshold >= 0 && iou_threshold <= 1.f), "iou_threshold must be in range [0, 1].");
|
||||
}
|
||||
|
||||
if (pc.score_threshold_ != nullptr) {
|
||||
score_threshold = *pc.score_threshold_;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace nms_helpers
|
||||
|
||||
bool NonMaxSuppression::SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const {
|
||||
float x1_min;
|
||||
float y1_min;
|
||||
float x1_max;
|
||||
float y1_max;
|
||||
float x2_min;
|
||||
float y2_min;
|
||||
float x2_max;
|
||||
float y2_max;
|
||||
// center_point_box_ only support 0 or 1
|
||||
if (0 == center_point_box_) {
|
||||
// boxes data format [y1, x1, y2, x2],
|
||||
MaxMin(boxes_data[4 * box_index1 + 1], boxes_data[4 * box_index1 + 3], x1_min, x1_max);
|
||||
MaxMin(boxes_data[4 * box_index1 + 0], boxes_data[4 * box_index1 + 2], y1_min, y1_max);
|
||||
MaxMin(boxes_data[4 * box_index2 + 1], boxes_data[4 * box_index2 + 3], x2_min, x2_max);
|
||||
MaxMin(boxes_data[4 * box_index2 + 0], boxes_data[4 * box_index2 + 2], y2_min, y2_max);
|
||||
} else {
|
||||
// 1 == center_point_box_ => boxes data format [x_center, y_center, width, height]
|
||||
float box1_width_half = boxes_data[4 * box_index1 + 2] / 2;
|
||||
float box1_height_half = boxes_data[4 * box_index1 + 3] / 2;
|
||||
float box2_width_half = boxes_data[4 * box_index2 + 2] / 2;
|
||||
float box2_height_half = boxes_data[4 * box_index2 + 3] / 2;
|
||||
Status NonMaxSuppressionBase::PrepareCompute(OpKernelContext* ctx, PrepareContext& pc) {
|
||||
const auto* boxes_tensor = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(boxes_tensor);
|
||||
pc.boxes_data_ = boxes_tensor->Data<float>();
|
||||
|
||||
x1_min = boxes_data[4 * box_index1 + 0] - box1_width_half;
|
||||
x1_max = boxes_data[4 * box_index1 + 0] + box1_width_half;
|
||||
y1_min = boxes_data[4 * box_index1 + 1] - box1_height_half;
|
||||
y1_max = boxes_data[4 * box_index1 + 1] + box1_height_half;
|
||||
const auto* scores_tensor = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(scores_tensor);
|
||||
pc.scores_data_ = scores_tensor->Data<float>();
|
||||
|
||||
x2_min = boxes_data[4 * box_index2 + 0] - box2_width_half;
|
||||
x2_max = boxes_data[4 * box_index2 + 0] + box2_width_half;
|
||||
y2_min = boxes_data[4 * box_index2 + 1] - box2_height_half;
|
||||
y2_max = boxes_data[4 * box_index2 + 1] + box2_height_half;
|
||||
const auto num_inputs = ctx->InputCount();
|
||||
|
||||
if (num_inputs > 2) {
|
||||
const auto* max_output_boxes_per_class_tensor = ctx->Input<Tensor>(2);
|
||||
if (max_output_boxes_per_class_tensor != nullptr) {
|
||||
pc.max_output_boxes_per_class_ = max_output_boxes_per_class_tensor->Data<int64_t>();
|
||||
}
|
||||
}
|
||||
|
||||
const float intersection_x_min = std::max(x1_min, x2_min);
|
||||
const float intersection_y_min = std::max(y1_min, y2_min);
|
||||
const float intersection_x_max = std::min(x1_max, x2_max);
|
||||
const float intersection_y_max = std::min(y1_max, y2_max);
|
||||
|
||||
const float intersection_area = std::max(intersection_x_max - intersection_x_min, static_cast<float>(0.0)) *
|
||||
std::max(intersection_y_max - intersection_y_min, static_cast<float>(0.0));
|
||||
|
||||
if (intersection_area <= static_cast<float>(0.0)) {
|
||||
return false;
|
||||
if (num_inputs > 3) {
|
||||
const auto* iou_threshold_tensor = ctx->Input<Tensor>(3);
|
||||
if (iou_threshold_tensor != nullptr) {
|
||||
pc.iou_threshold_ = iou_threshold_tensor->Data<float>();
|
||||
}
|
||||
}
|
||||
|
||||
const float area1 = (x1_max - x1_min) * (y1_max - y1_min);
|
||||
const float area2 = (x2_max - x2_min) * (y2_max - y2_min);
|
||||
const float union_area = area1 + area2 - intersection_area;
|
||||
|
||||
if (area1 <= static_cast<float>(0.0) || area2 <= static_cast<float>(0.0) || union_area <= static_cast<float>(0.0)) {
|
||||
return false;
|
||||
if (num_inputs > 4) {
|
||||
const auto* score_threshold_tensor = ctx->Input<Tensor>(4);
|
||||
if (score_threshold_tensor != nullptr) {
|
||||
pc.score_threshold_ = score_threshold_tensor->Data<float>();
|
||||
}
|
||||
}
|
||||
|
||||
const float intersection_over_union = intersection_area / union_area;
|
||||
const auto& boxes_shape = boxes_tensor->Shape();
|
||||
pc.boxes_size_ = boxes_shape.Size();
|
||||
const auto& scores_shape = scores_tensor->Shape();
|
||||
pc.scores_size_ = scores_shape.Size();
|
||||
|
||||
return intersection_over_union > iou_threshold;
|
||||
}
|
||||
|
||||
Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape,
|
||||
int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const {
|
||||
ORT_RETURN_IF_NOT(boxes_shape.NumDimensions() == 3, "boxes must be a 3D tensor.");
|
||||
ORT_RETURN_IF_NOT(scores_shape.NumDimensions() == 3, "scores must be a 3D tensor.");
|
||||
|
||||
auto boxes_dims = boxes_shape.GetDims();
|
||||
auto scores_dims = scores_shape.GetDims();
|
||||
ORT_RETURN_IF_NOT(boxes_dims[0] == scores_dims[0], "boxes and scores should have same num_batches.");
|
||||
ORT_RETURN_IF_NOT(boxes_dims[1] == scores_dims[2], "boxes and scores should have same spatial_dimention.");
|
||||
ORT_RETURN_IF_NOT(boxes_dims[1] == scores_dims[2], "boxes and scores should have same spatial_dimension.");
|
||||
ORT_RETURN_IF_NOT(boxes_dims[2] == 4, "The most inner dimension in boxes must have 4 data.");
|
||||
|
||||
const_cast<int64_t&>(num_batches_) = boxes_dims[0];
|
||||
const_cast<int64_t&>(num_classes_) = scores_dims[1];
|
||||
const_cast<int64_t&>(num_boxes_) = boxes_dims[1];
|
||||
|
||||
const auto* max_output_boxes_per_class_tensor = ctx->Input<Tensor>(2);
|
||||
if (max_output_boxes_per_class_tensor != nullptr) {
|
||||
max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data<int64_t>());
|
||||
max_output_boxes_per_class = max_output_boxes_per_class > 0 ? max_output_boxes_per_class : 0;
|
||||
}
|
||||
|
||||
const auto* iou_threshold_tensor = ctx->Input<Tensor>(3);
|
||||
if (iou_threshold_tensor != nullptr) {
|
||||
iou_threshold = *(iou_threshold_tensor->Data<float>());
|
||||
ORT_RETURN_IF_NOT((iou_threshold >= 0 && iou_threshold <= 1), "iou_threshold must be in range [0, 1].");
|
||||
}
|
||||
|
||||
const auto* score_threshold_tensor = ctx->Input<Tensor>(4);
|
||||
if (score_threshold_tensor != nullptr) {
|
||||
has_score_threshold = true;
|
||||
score_threshold = *(score_threshold_tensor->Data<float>());
|
||||
}
|
||||
pc.num_batches_ = boxes_dims[0];
|
||||
pc.num_classes_ = scores_dims[1];
|
||||
pc.num_boxes_ = boxes_dims[1];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
|
||||
const auto* boxes = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(boxes);
|
||||
const auto* scores = ctx->Input<Tensor>(1);
|
||||
ORT_ENFORCE(scores);
|
||||
|
||||
auto& boxes_shape = boxes->Shape();
|
||||
auto& scores_shape = scores->Shape();
|
||||
PrepareContext pc;
|
||||
auto ret = PrepareCompute(ctx, pc);
|
||||
ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage());
|
||||
|
||||
int64_t max_output_boxes_per_class = 0;
|
||||
float iou_threshold = 0;
|
||||
// Not so sure for the value range of score_threshold, so set a bool to indicate whether it has this input
|
||||
bool has_score_threshold = false;
|
||||
float score_threshold = 0;
|
||||
float iou_threshold = .0f;
|
||||
float score_threshold = .0f;
|
||||
|
||||
auto ret = ParepareCompute(ctx, boxes_shape, scores_shape, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold, has_score_threshold);
|
||||
ret = GetThresholdsFromInputs(pc, max_output_boxes_per_class, iou_threshold, score_threshold);
|
||||
ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage());
|
||||
|
||||
if (0 == max_output_boxes_per_class) {
|
||||
|
|
@ -153,63 +120,78 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto* boxes_data = boxes->Data<float>();
|
||||
const auto* scores_data = scores->Data<float>();
|
||||
const auto* const boxes_data = pc.boxes_data_;
|
||||
const auto* const scores_data = pc.scores_data_;
|
||||
|
||||
struct ScoreIndexPair {
|
||||
float score;
|
||||
int64_t index;
|
||||
float score_{};
|
||||
int64_t index_{};
|
||||
|
||||
ScoreIndexPair() = default;
|
||||
explicit ScoreIndexPair(float score, int64_t idx) : score_(score), index_(idx) {}
|
||||
|
||||
bool operator<(const ScoreIndexPair& rhs) const {
|
||||
return score_ < rhs.score_;
|
||||
}
|
||||
};
|
||||
|
||||
auto LessCompare = [](const ScoreIndexPair& lhs, const ScoreIndexPair& rhs) {
|
||||
return lhs.score < rhs.score;
|
||||
};
|
||||
const auto center_point_box = GetCenterPointBox();
|
||||
|
||||
std::vector<selected_index> tmp_selected_indices;
|
||||
for (int64_t batch_index = 0; batch_index < num_batches_; ++batch_index) {
|
||||
for (int64_t class_index = 0; class_index < num_classes_; ++class_index) {
|
||||
int64_t box_score_offset = (batch_index * num_classes_ + class_index) * num_boxes_;
|
||||
int64_t box_offset = batch_index * num_classes_ * num_boxes_ * 4;
|
||||
std::vector<SelectedIndex> selected_indices;
|
||||
for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
|
||||
for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
|
||||
int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_;
|
||||
int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4;
|
||||
// Filter by score_threshold_
|
||||
std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>, decltype(LessCompare)> sorted_scores_with_index(LessCompare);
|
||||
for (int64_t box_index = 0; box_index < num_boxes_; ++box_index) {
|
||||
if (!has_score_threshold || (has_score_threshold && scores_data[box_score_offset + box_index] > score_threshold)) {
|
||||
sorted_scores_with_index.emplace(ScoreIndexPair({scores_data[box_score_offset + box_index], box_index}));
|
||||
std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>> sorted_scores_with_index;
|
||||
const auto* class_scores = scores_data + box_score_offset;
|
||||
if (pc.score_threshold_ != nullptr) {
|
||||
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
|
||||
if (*class_scores > score_threshold) {
|
||||
sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
|
||||
sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index));
|
||||
}
|
||||
}
|
||||
|
||||
ScoreIndexPair next_top_score;
|
||||
std::vector<int64_t> selected_indicies_inside_class;
|
||||
// Get the next box with top score, filter by iou_threshold_
|
||||
// Get the next box with top score, filter by iou_threshold
|
||||
while (!sorted_scores_with_index.empty()) {
|
||||
next_top_score = sorted_scores_with_index.top();
|
||||
sorted_scores_with_index.pop();
|
||||
|
||||
bool selected = true;
|
||||
// Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold
|
||||
for (int64_t selected_indicies_inside_clas : selected_indicies_inside_class) {
|
||||
if (SuppressByIOU(boxes_data + box_offset, selected_indicies_inside_clas, next_top_score.index,
|
||||
iou_threshold)) {
|
||||
for (int64_t selected_index : selected_indicies_inside_class) {
|
||||
if (SuppressByIOU(boxes_data + box_offset, selected_index, next_top_score.index_,
|
||||
center_point_box, iou_threshold)) {
|
||||
selected = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (selected) {
|
||||
if (max_output_boxes_per_class > 0 && static_cast<int64_t>(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) {
|
||||
if (max_output_boxes_per_class > 0 &&
|
||||
static_cast<int64_t>(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) {
|
||||
break;
|
||||
}
|
||||
selected_indicies_inside_class.push_back(next_top_score.index);
|
||||
tmp_selected_indices.emplace_back(batch_index, class_index, next_top_score.index);
|
||||
selected_indicies_inside_class.push_back(next_top_score.index_);
|
||||
selected_indices.emplace_back(batch_index, class_index, next_top_score.index_);
|
||||
}
|
||||
} //while
|
||||
} //for class_index
|
||||
} //for batch_index
|
||||
|
||||
auto num_selected = static_cast<int32_t>(tmp_selected_indices.size());
|
||||
Tensor* selected_indices = ctx->Output(0, {num_selected, 3});
|
||||
ORT_ENFORCE(selected_indices);
|
||||
memcpy(selected_indices->MutableData<int64_t>(), tmp_selected_indices.data(), num_selected * sizeof(selected_index));
|
||||
const auto last_dim = 3;
|
||||
const auto num_selected = selected_indices.size();
|
||||
Tensor* output = ctx->Output(0, {static_cast<int64_t>(num_selected), last_dim});
|
||||
ORT_ENFORCE(output != nullptr);
|
||||
static_assert(last_dim * sizeof(int64_t) == sizeof(SelectedIndex), "Possible modification of SelectedIndex");
|
||||
memcpy(output->MutableData<int64_t>(), selected_indices.data(), num_selected * sizeof(SelectedIndex));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,37 +8,30 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
class NonMaxSuppression final : public OpKernel {
|
||||
public:
|
||||
NonMaxSuppression(const OpKernelInfo& info) : OpKernel(info) {
|
||||
struct PrepareContext;
|
||||
|
||||
class NonMaxSuppressionBase {
|
||||
protected:
|
||||
explicit NonMaxSuppressionBase(const OpKernelInfo& info) {
|
||||
center_point_box_ = info.GetAttrOrDefault<int64_t>("center_point_box", 0);
|
||||
ORT_ENFORCE(0 == center_point_box_ || 1 == center_point_box_, "center_point_box only support 0 or 1");
|
||||
num_batches_ = 0;
|
||||
num_classes_ = 0;
|
||||
num_boxes_ = 0;
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
static Status PrepareCompute(OpKernelContext* ctx, PrepareContext& pc);
|
||||
|
||||
private:
|
||||
bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const;
|
||||
void MaxMin(const float& lhs, const float& rhs, float& min, float& max) const;
|
||||
Status ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape,
|
||||
int64_t& max_output_boxes_per_batch, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const;
|
||||
int64_t GetCenterPointBox() const {
|
||||
return center_point_box_;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t center_point_box_;
|
||||
};
|
||||
|
||||
int64_t num_batches_;
|
||||
int64_t num_classes_;
|
||||
int64_t num_boxes_;
|
||||
class NonMaxSuppression final : public OpKernel, public NonMaxSuppressionBase {
|
||||
public:
|
||||
explicit NonMaxSuppression(const OpKernelInfo& info) : OpKernel(info), NonMaxSuppressionBase(info) {
|
||||
}
|
||||
|
||||
struct selected_index {
|
||||
selected_index(int64_t batch_index, int64_t class_index, int64_t box_index)
|
||||
: batch_index_(batch_index), class_index_(class_index), box_index_(box_index) {}
|
||||
int64_t batch_index_ = 0;
|
||||
int64_t class_index_ = 0;
|
||||
int64_t box_index_ = 0;
|
||||
};
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __NVCC__
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#define ORT_DEVICE __device__
|
||||
#define HelperMin(a, b) _Min(a, b)
|
||||
#define HelperMax(a, b) _Max(a, b)
|
||||
#else
|
||||
#include <algorithm>
|
||||
#define ORT_DEVICE
|
||||
#define HelperMin(a, b) std::min(a, b)
|
||||
#define HelperMax(a, b) std::max(a, b)
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct PrepareContext {
|
||||
const float* boxes_data_ = nullptr;
|
||||
int64_t boxes_size_ = 0ll;
|
||||
const float* scores_data_ = nullptr;
|
||||
int64_t scores_size_ = 0ll;
|
||||
// The below are ptrs since they cab be device specific
|
||||
const int64_t* max_output_boxes_per_class_ = nullptr;
|
||||
const float* score_threshold_ = nullptr;
|
||||
const float* iou_threshold_ = nullptr;
|
||||
int64_t num_batches_ = 0;
|
||||
int64_t num_classes_ = 0;
|
||||
int64_t num_boxes_ = 0;
|
||||
};
|
||||
|
||||
struct SelectedIndex {
|
||||
ORT_DEVICE
|
||||
SelectedIndex(int64_t batch_index, int64_t class_index, int64_t box_index)
|
||||
: batch_index_(batch_index), class_index_(class_index), box_index_(box_index) {}
|
||||
SelectedIndex() = default;
|
||||
int64_t batch_index_ = 0;
|
||||
int64_t class_index_ = 0;
|
||||
int64_t box_index_ = 0;
|
||||
};
|
||||
|
||||
#ifdef __NVCC__
|
||||
namespace cuda {
|
||||
#endif
|
||||
namespace nms_helpers {
|
||||
|
||||
ORT_DEVICE
|
||||
inline void MaxMin(float lhs, float rhs, float& min, float& max) {
|
||||
if (lhs >= rhs) {
|
||||
min = rhs;
|
||||
max = lhs;
|
||||
} else {
|
||||
min = lhs;
|
||||
max = rhs;
|
||||
}
|
||||
}
|
||||
|
||||
ORT_DEVICE
|
||||
inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2,
|
||||
int64_t center_point_box, float iou_threshold) {
|
||||
float x1_min{};
|
||||
float y1_min{};
|
||||
float x1_max{};
|
||||
float y1_max{};
|
||||
float x2_min{};
|
||||
float y2_min{};
|
||||
float x2_max{};
|
||||
float y2_max{};
|
||||
|
||||
const float* box1 = boxes_data + 4 * box_index1;
|
||||
const float* box2 = boxes_data + 4 * box_index2;
|
||||
// center_point_box_ only support 0 or 1
|
||||
if (0 == center_point_box) {
|
||||
// boxes data format [y1, x1, y2, x2],
|
||||
MaxMin(box1[1], box1[3], x1_min, x1_max);
|
||||
MaxMin(box1[0], box1[2], y1_min, y1_max);
|
||||
MaxMin(box2[1], box2[3], x2_min, x2_max);
|
||||
MaxMin(box2[0], box2[2], y2_min, y2_max);
|
||||
} else {
|
||||
// 1 == center_point_box_ => boxes data format [x_center, y_center, width, height]
|
||||
float box1_width_half = box1[2] / 2;
|
||||
float box1_height_half = box1[3] / 2;
|
||||
float box2_width_half = box2[2] / 2;
|
||||
float box2_height_half = box2[3] / 2;
|
||||
|
||||
x1_min = box1[0] - box1_width_half;
|
||||
x1_max = box1[0] + box1_width_half;
|
||||
y1_min = box1[1] - box1_height_half;
|
||||
y1_max = box1[1] + box1_height_half;
|
||||
|
||||
x2_min = box2[0] - box2_width_half;
|
||||
x2_max = box2[0] + box2_width_half;
|
||||
y2_min = box2[1] - box2_height_half;
|
||||
y2_max = box2[1] + box2_height_half;
|
||||
}
|
||||
|
||||
const float intersection_x_min = HelperMax(x1_min, x2_min);
|
||||
const float intersection_y_min = HelperMax(y1_min, y2_min);
|
||||
const float intersection_x_max = HelperMin(x1_max, x2_max);
|
||||
const float intersection_y_max = HelperMin(y1_max, y2_max);
|
||||
|
||||
const float intersection_area = HelperMax(intersection_x_max - intersection_x_min, .0f) *
|
||||
HelperMax(intersection_y_max - intersection_y_min, .0f);
|
||||
|
||||
if (intersection_area <= .0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const float area1 = (x1_max - x1_min) * (y1_max - y1_min);
|
||||
const float area2 = (x2_max - x2_min) * (y2_max - y2_min);
|
||||
const float union_area = area1 + area2 - intersection_area;
|
||||
|
||||
if (area1 <= .0f || area2 <= .0f || union_area <= .0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const float intersection_over_union = intersection_area / union_area;
|
||||
|
||||
return intersection_over_union > iou_threshold;
|
||||
}
|
||||
#ifdef __NVCC__
|
||||
} // namespace cuda
|
||||
#endif
|
||||
} // nms_helpers
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -267,7 +267,7 @@ TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) {
|
|||
test.AddInput<float>("iou_threshold", {}, {0.5f});
|
||||
test.AddInput<float>("score_threshold", {}, {0.0f});
|
||||
test.AddOutput<int64_t>("selected_indices", {0, 3}, {});
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimention.");
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimension.");
|
||||
}
|
||||
|
||||
TEST(NonMaxSuppressionOpTest, InvalidIOUThreshold) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue