mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix incorrect box offset computation in NMS op (#1624)
* More changes * Fix NMS * nits
This commit is contained in:
parent
0c5d2c998b
commit
7545b795df
2 changed files with 37 additions and 2 deletions
|
|
@ -141,7 +141,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
|
|||
for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
|
||||
for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
|
||||
int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_;
|
||||
int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4;
|
||||
int64_t box_offset = batch_index * pc.num_boxes_ * 4;
|
||||
// Filter by score_threshold_
|
||||
std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>> sorted_scores_with_index;
|
||||
const auto* class_scores = scores_data + box_score_offset;
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ TEST(NonMaxSuppressionOpTest, TwoClasses) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(NonMaxSuppressionOpTest, TwoBathes) {
|
||||
TEST(NonMaxSuppressionOpTest, TwoBatches_OneClass) {
|
||||
OpTester test("NonMaxSuppression", 10, kOnnxDomain);
|
||||
test.AddInput<float>("boxes", {2, 6, 4},
|
||||
{0.0f, 0.0f, 1.0f, 1.0f,
|
||||
|
|
@ -103,6 +103,41 @@ TEST(NonMaxSuppressionOpTest, TwoBathes) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(NonMaxSuppressionOpTest, TwoBatches_TwoClasses) {
|
||||
OpTester test("NonMaxSuppression", 10, kOnnxDomain);
|
||||
test.AddInput<float>("boxes", {2, 5, 4},
|
||||
{0.0f, 0.0f, 0.3f, 0.3f,
|
||||
0.0f, 0.0f, 0.4f, 0.4f,
|
||||
0.0f, 0.0f, 0.5f, 0.5f,
|
||||
0.5f, 0.5f, 0.9f, 0.9f,
|
||||
0.5f, 0.5f, 1.0f, 1.0f,
|
||||
|
||||
0.0f, 0.0f, 0.3f, 0.3f,
|
||||
0.0f, 0.0f, 0.4f, 0.4f,
|
||||
0.5f, 0.5f, 0.95f, 0.95f,
|
||||
0.5f, 0.5f, 0.96f, 0.96f,
|
||||
0.5f, 0.5f, 1.0f, 1.0f});
|
||||
test.AddInput<float>("scores", {2, 2, 5},
|
||||
{0.1f, 0.2f, 0.6f, 0.3f, 0.9f,
|
||||
0.1f, 0.2f, 0.6f, 0.3f, 0.9f,
|
||||
|
||||
0.1f, 0.2f, 0.6f, 0.3f, 0.9f,
|
||||
0.1f, 0.2f, 0.6f, 0.3f, 0.9f});
|
||||
test.AddInput<int64_t>("max_output_boxes_per_class", {}, {2L});
|
||||
test.AddInput<float>("iou_threshold", {}, {0.8f});
|
||||
test.AddOutput<int64_t>("selected_indices", {8, 3},
|
||||
{0L, 0L, 4L,
|
||||
0L, 0L, 2L,
|
||||
0L, 1L, 4L,
|
||||
0L, 1L, 2L,
|
||||
|
||||
1L, 0L, 4L,
|
||||
1L, 0L, 1L,
|
||||
1L, 1L, 4L,
|
||||
1L, 1L, 1L});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(NonMaxSuppressionOpTest, WithScoreThreshold) {
|
||||
OpTester test("NonMaxSuppression", 10, kOnnxDomain);
|
||||
test.AddInput<float>("boxes", {1, 6, 4},
|
||||
|
|
|
|||
Loading…
Reference in a new issue