Fix incorrect box offset computation in NMS op (#1624)

* More changes

* Fix NMS

* nits
This commit is contained in:
Hariharan Seshadri 2019-08-15 11:41:10 -07:00 committed by GitHub
parent 0c5d2c998b
commit 7545b795df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 2 deletions

View file

@ -141,7 +141,7 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_;
int64_t box_offset = batch_index * pc.num_classes_ * pc.num_boxes_ * 4;
int64_t box_offset = batch_index * pc.num_boxes_ * 4;
// Filter by score_threshold_
std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>> sorted_scores_with_index;
const auto* class_scores = scores_data + box_score_offset;

View file

@ -73,7 +73,7 @@ TEST(NonMaxSuppressionOpTest, TwoClasses) {
test.Run();
}
TEST(NonMaxSuppressionOpTest, TwoBathes) {
TEST(NonMaxSuppressionOpTest, TwoBatches_OneClass) {
OpTester test("NonMaxSuppression", 10, kOnnxDomain);
test.AddInput<float>("boxes", {2, 6, 4},
{0.0f, 0.0f, 1.0f, 1.0f,
@ -103,6 +103,41 @@ TEST(NonMaxSuppressionOpTest, TwoBathes) {
test.Run();
}
TEST(NonMaxSuppressionOpTest, TwoBatches_TwoClasses) {
OpTester test("NonMaxSuppression", 10, kOnnxDomain);
test.AddInput<float>("boxes", {2, 5, 4},
{0.0f, 0.0f, 0.3f, 0.3f,
0.0f, 0.0f, 0.4f, 0.4f,
0.0f, 0.0f, 0.5f, 0.5f,
0.5f, 0.5f, 0.9f, 0.9f,
0.5f, 0.5f, 1.0f, 1.0f,
0.0f, 0.0f, 0.3f, 0.3f,
0.0f, 0.0f, 0.4f, 0.4f,
0.5f, 0.5f, 0.95f, 0.95f,
0.5f, 0.5f, 0.96f, 0.96f,
0.5f, 0.5f, 1.0f, 1.0f});
test.AddInput<float>("scores", {2, 2, 5},
{0.1f, 0.2f, 0.6f, 0.3f, 0.9f,
0.1f, 0.2f, 0.6f, 0.3f, 0.9f,
0.1f, 0.2f, 0.6f, 0.3f, 0.9f,
0.1f, 0.2f, 0.6f, 0.3f, 0.9f});
test.AddInput<int64_t>("max_output_boxes_per_class", {}, {2L});
test.AddInput<float>("iou_threshold", {}, {0.8f});
test.AddOutput<int64_t>("selected_indices", {8, 3},
{0L, 0L, 4L,
0L, 0L, 2L,
0L, 1L, 4L,
0L, 1L, 2L,
1L, 0L, 4L,
1L, 0L, 1L,
1L, 1L, 4L,
1L, 1L, 1L});
test.Run();
}
TEST(NonMaxSuppressionOpTest, WithScoreThreshold) {
OpTester test("NonMaxSuppression", 10, kOnnxDomain);
test.AddInput<float>("boxes", {1, 6, 4},