Fix NMS const_cast that modified kernel state creating (#1303)

* Fix NMS const_cast that modified kernel state, creating
  a thread-safety issue. Refactor for a future CUDA implementation.
This commit is contained in:
Dmitri Smirnov 2019-06-28 09:41:17 -07:00 committed by GitHub
parent 04d581995d
commit 2f698bd54b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 248 additions and 144 deletions

View file

@ -12,6 +12,7 @@ limitations under the License.
/* Modifications Copyright (c) Microsoft. */
#include "non_max_suppression.h"
#include "non_max_suppression_helper.h"
#include <queue>
namespace onnxruntime {
@ -24,128 +25,94 @@ ONNX_OPERATOR_KERNEL_EX(
KernelDefBuilder(),
NonMaxSuppression);
void NonMaxSuppression::MaxMin(const float& lhs, const float& rhs, float& min, float& max) const {
if (lhs >= rhs) {
min = rhs;
max = lhs;
} else {
min = lhs;
max = rhs;
using namespace nms_helpers;
// CPU version
namespace nms_helpers {
// Copies the optional scalar inputs referenced by `pc` into the caller's variables.
//
// Each output is written only when the matching pointer in `pc` is non-null;
// otherwise the caller's current value is kept.
//   max_output_boxes_per_class - receives the requested limit, with negative
//                                values clamped to 0.
//   iou_threshold              - receives the IOU threshold; a value outside
//                                [0, 1] makes the function return a failure
//                                status (the output has already been written
//                                at that point).
//   score_threshold            - receives the score threshold unchanged.
// Returns Status::OK() unless the iou_threshold range check fails.
Status GetThresholdsFromInputs(const PrepareContext& pc,
                               int64_t& max_output_boxes_per_class,
                               float& iou_threshold,
                               float& score_threshold) {
  if (const int64_t* requested_max = pc.max_output_boxes_per_class_) {
    // A negative request is treated as "no boxes".
    max_output_boxes_per_class = (*requested_max < 0) ? int64_t{0} : *requested_max;
  }

  if (const float* requested_iou = pc.iou_threshold_) {
    iou_threshold = *requested_iou;
    ORT_RETURN_IF_NOT((iou_threshold >= 0 && iou_threshold <= 1.f), "iou_threshold must be in range [0, 1].");
  }

  if (const float* requested_score = pc.score_threshold_) {
    score_threshold = *requested_score;
  }

  return Status::OK();
}
} // namespace nms_helpers
bool NonMaxSuppression::SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const {
float x1_min;
float y1_min;
float x1_max;
float y1_max;
float x2_min;
float y2_min;
float x2_max;
float y2_max;
// center_point_box_ only support 0 or 1
if (0 == center_point_box_) {
// boxes data format [y1, x1, y2, x2],
MaxMin(boxes_data[4 * box_index1 + 1], boxes_data[4 * box_index1 + 3], x1_min, x1_max);
MaxMin(boxes_data[4 * box_index1 + 0], boxes_data[4 * box_index1 + 2], y1_min, y1_max);
MaxMin(boxes_data[4 * box_index2 + 1], boxes_data[4 * box_index2 + 3], x2_min, x2_max);
MaxMin(boxes_data[4 * box_index2 + 0], boxes_data[4 * box_index2 + 2], y2_min, y2_max);
} else {
// 1 == center_point_box_ => boxes data format [x_center, y_center, width, height]
float box1_width_half = boxes_data[4 * box_index1 + 2] / 2;
float box1_height_half = boxes_data[4 * box_index1 + 3] / 2;
float box2_width_half = boxes_data[4 * box_index2 + 2] / 2;
float box2_height_half = boxes_data[4 * box_index2 + 3] / 2;
Status NonMaxSuppressionBase::PrepareCompute(OpKernelContext* ctx, PrepareContext& pc) {
const auto* boxes_tensor = ctx->Input<Tensor>(0);
ORT_ENFORCE(boxes_tensor);
pc.boxes_data_ = boxes_tensor->Data<float>();
x1_min = boxes_data[4 * box_index1 + 0] - box1_width_half;
x1_max = boxes_data[4 * box_index1 + 0] + box1_width_half;
y1_min = boxes_data[4 * box_index1 + 1] - box1_height_half;
y1_max = boxes_data[4 * box_index1 + 1] + box1_height_half;
const auto* scores_tensor = ctx->Input<Tensor>(1);
ORT_ENFORCE(scores_tensor);
pc.scores_data_ = scores_tensor->Data<float>();
x2_min = boxes_data[4 * box_index2 + 0] - box2_width_half;
x2_max = boxes_data[4 * box_index2 + 0] + box2_width_half;
y2_min = boxes_data[4 * box_index2 + 1] - box2_height_half;
y2_max = boxes_data[4 * box_index2 + 1] + box2_height_half;
const auto num_inputs = ctx->InputCount();
if (num_inputs > 2) {
const auto* max_output_boxes_per_class_tensor = ctx->Input<Tensor>(2);
if (max_output_boxes_per_class_tensor != nullptr) {
pc.max_output_boxes_per_class_ = max_output_boxes_per_class_tensor->Data<int64_t>();
}
}
const float intersection_x_min = std::max(x1_min, x2_min);
const float intersection_y_min = std::max(y1_min, y2_min);
const float intersection_x_max = std::min(x1_max, x2_max);
const float intersection_y_max = std::min(y1_max, y2_max);
const float intersection_area = std::max(intersection_x_max - intersection_x_min, static_cast<float>(0.0)) *
std::max(intersection_y_max - intersection_y_min, static_cast<float>(0.0));
if (intersection_area <= static_cast<float>(0.0)) {
return false;
if (num_inputs > 3) {
const auto* iou_threshold_tensor = ctx->Input<Tensor>(3);
if (iou_threshold_tensor != nullptr) {
pc.iou_threshold_ = iou_threshold_tensor->Data<float>();
}
}
const float area1 = (x1_max - x1_min) * (y1_max - y1_min);
const float area2 = (x2_max - x2_min) * (y2_max - y2_min);
const float union_area = area1 + area2 - intersection_area;
if (area1 <= static_cast<float>(0.0) || area2 <= static_cast<float>(0.0) || union_area <= static_cast<float>(0.0)) {
return false;
if (num_inputs > 4) {
const auto* score_threshold_tensor = ctx->Input<Tensor>(4);
if (score_threshold_tensor != nullptr) {
pc.score_threshold_ = score_threshold_tensor->Data<float>();
}
}
const float intersection_over_union = intersection_area / union_area;
const auto& boxes_shape = boxes_tensor->Shape();
pc.boxes_size_ = boxes_shape.Size();
const auto& scores_shape = scores_tensor->Shape();
pc.scores_size_ = scores_shape.Size();
return intersection_over_union > iou_threshold;
}
Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape,
int64_t& max_output_boxes_per_class, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const {
ORT_RETURN_IF_NOT(boxes_shape.NumDimensions() == 3, "boxes must be a 3D tensor.");
ORT_RETURN_IF_NOT(scores_shape.NumDimensions() == 3, "scores must be a 3D tensor.");
auto boxes_dims = boxes_shape.GetDims();
auto scores_dims = scores_shape.GetDims();
ORT_RETURN_IF_NOT(boxes_dims[0] == scores_dims[0], "boxes and scores should have same num_batches.");
ORT_RETURN_IF_NOT(boxes_dims[1] == scores_dims[2], "boxes and scores should have same spatial_dimention.");
ORT_RETURN_IF_NOT(boxes_dims[1] == scores_dims[2], "boxes and scores should have same spatial_dimension.");
ORT_RETURN_IF_NOT(boxes_dims[2] == 4, "The most inner dimension in boxes must have 4 data.");
const_cast<int64_t&>(num_batches_) = boxes_dims[0];
const_cast<int64_t&>(num_classes_) = scores_dims[1];
const_cast<int64_t&>(num_boxes_) = boxes_dims[1];
const auto* max_output_boxes_per_class_tensor = ctx->Input<Tensor>(2);
if (max_output_boxes_per_class_tensor != nullptr) {
max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data<int64_t>());
max_output_boxes_per_class = max_output_boxes_per_class > 0 ? max_output_boxes_per_class : 0;
}
const auto* iou_threshold_tensor = ctx->Input<Tensor>(3);
if (iou_threshold_tensor != nullptr) {
iou_threshold = *(iou_threshold_tensor->Data<float>());
ORT_RETURN_IF_NOT((iou_threshold >= 0 && iou_threshold <= 1), "iou_threshold must be in range [0, 1].");
}
const auto* score_threshold_tensor = ctx->Input<Tensor>(4);
if (score_threshold_tensor != nullptr) {
has_score_threshold = true;
score_threshold = *(score_threshold_tensor->Data<float>());
}
pc.num_batches_ = boxes_dims[0];
pc.num_classes_ = scores_dims[1];
pc.num_boxes_ = boxes_dims[1];
return Status::OK();
}
Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
const auto* boxes = ctx->Input<Tensor>(0);
ORT_ENFORCE(boxes);
const auto* scores = ctx->Input<Tensor>(1);
ORT_ENFORCE(scores);
auto& boxes_shape = boxes->Shape();
auto& scores_shape = scores->Shape();
PrepareContext pc;
auto ret = PrepareCompute(ctx, pc);
ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage());
int64_t max_output_boxes_per_class = 0;
float iou_threshold = 0;
// Not so sure for the value range of score_threshold, so set a bool to indicate whether it has this input
bool has_score_threshold = false;
float score_threshold = 0;
float iou_threshold = .0f;
float score_threshold = .0f;
auto ret = ParepareCompute(ctx, boxes_shape, scores_shape, max_output_boxes_per_class,
iou_threshold, score_threshold, has_score_threshold);
ret = GetThresholdsFromInputs(pc, max_output_boxes_per_class, iou_threshold, score_threshold);
ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage());
if (0 == max_output_boxes_per_class) {
@ -153,63 +120,78 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
return Status::OK();
}
const auto* boxes_data = boxes->Data<float>();
const auto* scores_data = scores->Data<float>();
const auto* const boxes_data = pc.boxes_data_;
const auto* const scores_data = pc.scores_data_;
struct ScoreIndexPair {
float score;
int64_t index;
float score_{};
int64_t index_{};
ScoreIndexPair() = default;
explicit ScoreIndexPair(float score, int64_t idx) : score_(score), index_(idx) {}
bool operator<(const ScoreIndexPair& rhs) const {
return score_ < rhs.score_;
}
};
auto LessCompare = [](const ScoreIndexPair& lhs, const ScoreIndexPair& rhs) {
return lhs.score < rhs.score;
};
const auto center_point_box = GetCenterPointBox();
std::vector<selected_index> tmp_selected_indices;
for (int64_t batch_index = 0; batch_index < num_batches_; ++batch_index) {
for (int64_t class_index = 0; class_index < num_classes_; ++class_index) {
int64_t box_score_offset = (batch_index * num_classes_ + class_index) * num_boxes_;
int64_t box_offset = batch_index * num_classes_ * num_boxes_ * 4;
std::vector<SelectedIndex> selected_indices;
for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_;
int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4;
// Filter by score_threshold_
std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>, decltype(LessCompare)> sorted_scores_with_index(LessCompare);
for (int64_t box_index = 0; box_index < num_boxes_; ++box_index) {
if (!has_score_threshold || (has_score_threshold && scores_data[box_score_offset + box_index] > score_threshold)) {
sorted_scores_with_index.emplace(ScoreIndexPair({scores_data[box_score_offset + box_index], box_index}));
std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>> sorted_scores_with_index;
const auto* class_scores = scores_data + box_score_offset;
if (pc.score_threshold_ != nullptr) {
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
if (*class_scores > score_threshold) {
sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index));
}
}
} else {
for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
sorted_scores_with_index.push(ScoreIndexPair(*class_scores, box_index));
}
}
ScoreIndexPair next_top_score;
std::vector<int64_t> selected_indicies_inside_class;
// Get the next box with top score, filter by iou_threshold_
// Get the next box with top score, filter by iou_threshold
while (!sorted_scores_with_index.empty()) {
next_top_score = sorted_scores_with_index.top();
sorted_scores_with_index.pop();
bool selected = true;
// Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold
for (int64_t selected_indicies_inside_clas : selected_indicies_inside_class) {
if (SuppressByIOU(boxes_data + box_offset, selected_indicies_inside_clas, next_top_score.index,
iou_threshold)) {
for (int64_t selected_index : selected_indicies_inside_class) {
if (SuppressByIOU(boxes_data + box_offset, selected_index, next_top_score.index_,
center_point_box, iou_threshold)) {
selected = false;
break;
}
}
if (selected) {
if (max_output_boxes_per_class > 0 && static_cast<int64_t>(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) {
if (max_output_boxes_per_class > 0 &&
static_cast<int64_t>(selected_indicies_inside_class.size()) >= max_output_boxes_per_class) {
break;
}
selected_indicies_inside_class.push_back(next_top_score.index);
tmp_selected_indices.emplace_back(batch_index, class_index, next_top_score.index);
selected_indicies_inside_class.push_back(next_top_score.index_);
selected_indices.emplace_back(batch_index, class_index, next_top_score.index_);
}
} //while
} //for class_index
} //for batch_index
auto num_selected = static_cast<int32_t>(tmp_selected_indices.size());
Tensor* selected_indices = ctx->Output(0, {num_selected, 3});
ORT_ENFORCE(selected_indices);
memcpy(selected_indices->MutableData<int64_t>(), tmp_selected_indices.data(), num_selected * sizeof(selected_index));
const auto last_dim = 3;
const auto num_selected = selected_indices.size();
Tensor* output = ctx->Output(0, {static_cast<int64_t>(num_selected), last_dim});
ORT_ENFORCE(output != nullptr);
static_assert(last_dim * sizeof(int64_t) == sizeof(SelectedIndex), "Possible modification of SelectedIndex");
memcpy(output->MutableData<int64_t>(), selected_indices.data(), num_selected * sizeof(SelectedIndex));
return Status::OK();
}

View file

@ -8,37 +8,30 @@
namespace onnxruntime {
class NonMaxSuppression final : public OpKernel {
public:
NonMaxSuppression(const OpKernelInfo& info) : OpKernel(info) {
struct PrepareContext;
class NonMaxSuppressionBase {
protected:
explicit NonMaxSuppressionBase(const OpKernelInfo& info) {
center_point_box_ = info.GetAttrOrDefault<int64_t>("center_point_box", 0);
ORT_ENFORCE(0 == center_point_box_ || 1 == center_point_box_, "center_point_box only support 0 or 1");
num_batches_ = 0;
num_classes_ = 0;
num_boxes_ = 0;
}
Status Compute(OpKernelContext* context) const override;
static Status PrepareCompute(OpKernelContext* ctx, PrepareContext& pc);
private:
bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2, float iou_threshold) const;
void MaxMin(const float& lhs, const float& rhs, float& min, float& max) const;
Status ParepareCompute(OpKernelContext* ctx, const TensorShape& boxes_shape, const TensorShape& scores_shape,
int64_t& max_output_boxes_per_batch, float& iou_threshold, float& score_threshold, bool& has_score_threshold) const;
// Returns the value of the "center_point_box" attribute read in the constructor,
// where it is enforced to be either 0 or 1.
int64_t GetCenterPointBox() const {
return center_point_box_;
}
private:
int64_t center_point_box_;
};
int64_t num_batches_;
int64_t num_classes_;
int64_t num_boxes_;
class NonMaxSuppression final : public OpKernel, public NonMaxSuppressionBase {
public:
explicit NonMaxSuppression(const OpKernelInfo& info) : OpKernel(info), NonMaxSuppressionBase(info) {
}
struct selected_index {
selected_index(int64_t batch_index, int64_t class_index, int64_t box_index)
: batch_index_(batch_index), class_index_(class_index), box_index_(box_index) {}
int64_t batch_index_ = 0;
int64_t class_index_ = 0;
int64_t box_index_ = 0;
};
Status Compute(OpKernelContext* context) const override;
};
} // namespace onnxruntime

View file

@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
// Compile-target switch for the helpers in this header.
// ORT_DEVICE is placed in front of the helper functions, and HelperMin/HelperMax
// are the min/max primitives those helpers use.
#ifdef __NVCC__
// Compiled with nvcc: helpers become __device__ functions and min/max go through
// _Min/_Max (presumably provided by the CUDA common header below -- TODO confirm).
#include "core/providers/cuda/cu_inc/common.cuh"
#define ORT_DEVICE __device__
#define HelperMin(a, b) _Min(a, b)
#define HelperMax(a, b) _Max(a, b)
#else
// Any other compiler: ORT_DEVICE expands to nothing and min/max map to std::min/std::max.
#include <algorithm>
#define ORT_DEVICE
#define HelperMin(a, b) std::min(a, b)
#define HelperMax(a, b) std::max(a, b)
#endif
namespace onnxruntime {
// Per-call view of the NonMaxSuppression inputs.
// It is filled in by NonMaxSuppressionBase::PrepareCompute and holds only raw,
// non-owning pointers and sizes, so no state has to be stored on the kernel.
struct PrepareContext {
  // Raw data of the boxes input and the total element count of its shape.
  const float* boxes_data_{nullptr};
  int64_t boxes_size_{0};

  // Raw data of the scores input and the total element count of its shape.
  const float* scores_data_{nullptr};
  int64_t scores_size_{0};

  // Optional scalar inputs. They are kept as pointers because they can be
  // device specific; a pointer stays null when the input was not supplied.
  const int64_t* max_output_boxes_per_class_{nullptr};
  const float* score_threshold_{nullptr};
  const float* iou_threshold_{nullptr};

  // Dimensions taken from the input shapes: boxes dim 0, scores dim 1 and
  // boxes dim 1 respectively.
  int64_t num_batches_{0};
  int64_t num_classes_{0};
  int64_t num_boxes_{0};
};
struct SelectedIndex {
ORT_DEVICE
SelectedIndex(int64_t batch_index, int64_t class_index, int64_t box_index)
: batch_index_(batch_index), class_index_(class_index), box_index_(box_index) {}
SelectedIndex() = default;
int64_t batch_index_ = 0;
int64_t class_index_ = 0;
int64_t box_index_ = 0;
};
#ifdef __NVCC__
namespace cuda {
#endif
namespace nms_helpers {
ORT_DEVICE
inline void MaxMin(float lhs, float rhs, float& min, float& max) {
  // Orders a pair of values: `min` receives the smaller one and `max` the larger.
  // When `lhs >= rhs` does not hold (which includes any NaN operand), `lhs` is
  // written to `min` and `rhs` to `max`.
  const bool lhs_not_smaller = lhs >= rhs;
  min = lhs_not_smaller ? rhs : lhs;
  max = lhs_not_smaller ? lhs : rhs;
}
ORT_DEVICE
inline bool SuppressByIOU(const float* boxes_data, int64_t box_index1, int64_t box_index2,
                          int64_t center_point_box, float iou_threshold) {
  // Tells whether two boxes overlap by more than `iou_threshold`
  // (intersection area divided by union area, strictly greater).
  //
  //   boxes_data       - flat array holding 4 floats per box.
  //   box_index1/2     - indices of the two boxes inside boxes_data.
  //   center_point_box - 0: a box is [y1, x1, y2, x2], corners in either order;
  //                      otherwise: a box is [x_center, y_center, width, height].
  //
  // Returns false when the intersection area, either box area or the union
  // area is not positive.
  const float* const first = boxes_data + 4 * box_index1;
  const float* const second = boxes_data + 4 * box_index2;

  float a_x_lo{};
  float a_x_hi{};
  float a_y_lo{};
  float a_y_hi{};
  float b_x_lo{};
  float b_x_hi{};
  float b_y_lo{};
  float b_y_hi{};

  if (0 != center_point_box) {
    // Center format: expand each box by half its width/height around the center.
    const float a_half_width = first[2] / 2;
    const float a_half_height = first[3] / 2;
    const float b_half_width = second[2] / 2;
    const float b_half_height = second[3] / 2;

    a_x_lo = first[0] - a_half_width;
    a_x_hi = first[0] + a_half_width;
    a_y_lo = first[1] - a_half_height;
    a_y_hi = first[1] + a_half_height;

    b_x_lo = second[0] - b_half_width;
    b_x_hi = second[0] + b_half_width;
    b_y_lo = second[1] - b_half_height;
    b_y_hi = second[1] + b_half_height;
  } else {
    // Corner format: the two corners may come in any order, so sort each axis.
    MaxMin(first[1], first[3], a_x_lo, a_x_hi);
    MaxMin(first[0], first[2], a_y_lo, a_y_hi);
    MaxMin(second[1], second[3], b_x_lo, b_x_hi);
    MaxMin(second[0], second[2], b_y_lo, b_y_hi);
  }

  // Intersection rectangle; a negative extent is clamped to zero.
  const float overlap_x_lo = HelperMax(a_x_lo, b_x_lo);
  const float overlap_y_lo = HelperMax(a_y_lo, b_y_lo);
  const float overlap_x_hi = HelperMin(a_x_hi, b_x_hi);
  const float overlap_y_hi = HelperMin(a_y_hi, b_y_hi);
  const float overlap_width = HelperMax(overlap_x_hi - overlap_x_lo, .0f);
  const float overlap_height = HelperMax(overlap_y_hi - overlap_y_lo, .0f);
  const float overlap_area = overlap_width * overlap_height;
  if (overlap_area <= .0f) {
    return false;
  }

  const float first_area = (a_x_hi - a_x_lo) * (a_y_hi - a_y_lo);
  const float second_area = (b_x_hi - b_x_lo) * (b_y_hi - b_y_lo);
  const float combined_area = first_area + second_area - overlap_area;
  if (first_area <= .0f || second_area <= .0f || combined_area <= .0f) {
    return false;
  }

  const float iou = overlap_area / combined_area;
  return iou > iou_threshold;
}
#ifdef __NVCC__
} // namespace cuda
#endif
} // nms_helpers
} // namespace onnxruntime

View file

@ -267,7 +267,7 @@ TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) {
test.AddInput<float>("iou_threshold", {}, {0.5f});
test.AddInput<float>("score_threshold", {}, {0.0f});
test.AddOutput<int64_t>("selected_indices", {0, 3}, {});
test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimention.");
test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimension.");
}
TEST(NonMaxSuppressionOpTest, InvalidIOUThreshold) {