Update NMS to support max_output_boxes_per_class = 0. NMS will do nothing for this case. (#816)

This commit is contained in:
Hector Li 2019-04-11 10:09:33 -07:00 committed by GitHub
parent 56749a84ee
commit 0741baf867
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 4 deletions

View file

@ -105,7 +105,7 @@ Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShap
const Tensor* max_output_boxes_per_class_tensor = ctx->Input<Tensor>(2);
if (max_output_boxes_per_class_tensor != nullptr) {
max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data<int32_t>());
ORT_RETURN_IF_NOT(max_output_boxes_per_class > 0, "max_output_boxes_per_class should be greater than 0.");
max_output_boxes_per_class = max_output_boxes_per_class > 0 ? max_output_boxes_per_class : 0;
}
const Tensor* iou_threshold_tensor = ctx->Input<Tensor>(3);
@ -142,6 +142,11 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
iou_threshold, score_threshold, has_score_threshold);
ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage());
if (0 == max_output_boxes_per_class) {
ctx->Output(0, {0, 3});
return Status::OK();
}
const float* boxes_data = boxes->Data<float>();
const float* scores_data = scores->Data<float>();

View file

@ -893,7 +893,7 @@ Note: The boxes doesn't has class dimension which means it alwasy has scores cal
.Input(
2,
"max_output_boxes_per_class",
"Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Value should be greater than 0",
"Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar.",
"tensor(int32)",
OpSchema::Optional)
.Input(

View file

@ -266,7 +266,7 @@ TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) {
test.AddInput<int32_t>("max_output_boxes_per_class", {}, {30L});
test.AddInput<float>("iou_threshold", {}, {0.5f});
test.AddInput<float>("score_threshold", {}, {0.0f});
test.AddOutput<int32_t>("selected_indices", {0}, {});
test.AddOutput<int32_t>("selected_indices", {0, 3}, {});
test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimention.");
}
@ -277,7 +277,7 @@ TEST(NonMaxSuppressionOpTest, InvalidIOUThreshold) {
test.AddInput<int32_t>("max_output_boxes_per_class", {}, {3L});
test.AddInput<float>("iou_threshold", {}, {1.2f});
test.AddInput<float>("score_threshold", {}, {0.0f});
test.AddOutput<int32_t>("selected_indices", {0}, {});
test.AddOutput<int32_t>("selected_indices", {0, 3}, {});
test.Run(OpTester::ExpectResult::kExpectFailure, "iou_threshold must be in range [0, 1]");
}
@ -292,5 +292,22 @@ TEST(NonMaxSuppressionOpTest, EmptyInput) {
test.Run();
}
TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) {
OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain);
test.AddInput<float>("boxes", {1, 6, 4},
{0.0f, 0.0f, 1.0f, 1.0f,
0.0f, 0.1f, 1.0f, 1.1f,
0.0f, -0.1f, 1.0f, 0.9f,
0.0f, 10.0f, 1.0f, 11.0f,
0.0f, 10.1f, 1.0f, 11.1f,
0.0f, 100.0f, 1.0f, 101.0f});
test.AddInput<float>("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f});
test.AddInput<int32_t>("max_output_boxes_per_class", {}, {0L});
test.AddInput<float>("iou_threshold", {}, {0.5f});
test.AddInput<float>("score_threshold", {}, {0.4f});
test.AddOutput<int32_t>("selected_indices", {0, 3}, {});
test.Run();
}
} // namespace test
} // namespace onnxruntime