mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Update NMS to support max_output_boxes_per_class = 0. NMS will do nothing for this case. (#816)
This commit is contained in:
parent
56749a84ee
commit
0741baf867
3 changed files with 26 additions and 4 deletions
|
|
@ -105,7 +105,7 @@ Status NonMaxSuppression::ParepareCompute(OpKernelContext* ctx, const TensorShap
|
|||
const Tensor* max_output_boxes_per_class_tensor = ctx->Input<Tensor>(2);
|
||||
if (max_output_boxes_per_class_tensor != nullptr) {
|
||||
max_output_boxes_per_class = *(max_output_boxes_per_class_tensor->Data<int32_t>());
|
||||
ORT_RETURN_IF_NOT(max_output_boxes_per_class > 0, "max_output_boxes_per_class should be greater than 0.");
|
||||
max_output_boxes_per_class = max_output_boxes_per_class > 0 ? max_output_boxes_per_class : 0;
|
||||
}
|
||||
|
||||
const Tensor* iou_threshold_tensor = ctx->Input<Tensor>(3);
|
||||
|
|
@ -142,6 +142,11 @@ Status NonMaxSuppression::Compute(OpKernelContext* ctx) const {
|
|||
iou_threshold, score_threshold, has_score_threshold);
|
||||
ORT_RETURN_IF_NOT(ret.IsOK(), ret.ErrorMessage());
|
||||
|
||||
if (0 == max_output_boxes_per_class) {
|
||||
ctx->Output(0, {0, 3});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const float* boxes_data = boxes->Data<float>();
|
||||
const float* scores_data = scores->Data<float>();
|
||||
|
||||
|
|
|
|||
|
|
@ -893,7 +893,7 @@ Note: The boxes doesn't has class dimension which means it alwasy has scores cal
|
|||
.Input(
|
||||
2,
|
||||
"max_output_boxes_per_class",
|
||||
"Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Value should be greater than 0",
|
||||
"Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar.",
|
||||
"tensor(int32)",
|
||||
OpSchema::Optional)
|
||||
.Input(
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ TEST(NonMaxSuppressionOpTest, InconsistentBoxAndScoreShapes) {
|
|||
test.AddInput<int32_t>("max_output_boxes_per_class", {}, {30L});
|
||||
test.AddInput<float>("iou_threshold", {}, {0.5f});
|
||||
test.AddInput<float>("score_threshold", {}, {0.0f});
|
||||
test.AddOutput<int32_t>("selected_indices", {0}, {});
|
||||
test.AddOutput<int32_t>("selected_indices", {0, 3}, {});
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "boxes and scores should have same spatial_dimention.");
|
||||
}
|
||||
|
||||
|
|
@ -277,7 +277,7 @@ TEST(NonMaxSuppressionOpTest, InvalidIOUThreshold) {
|
|||
test.AddInput<int32_t>("max_output_boxes_per_class", {}, {3L});
|
||||
test.AddInput<float>("iou_threshold", {}, {1.2f});
|
||||
test.AddInput<float>("score_threshold", {}, {0.0f});
|
||||
test.AddOutput<int32_t>("selected_indices", {0}, {});
|
||||
test.AddOutput<int32_t>("selected_indices", {0, 3}, {});
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "iou_threshold must be in range [0, 1]");
|
||||
}
|
||||
|
||||
|
|
@ -292,5 +292,22 @@ TEST(NonMaxSuppressionOpTest, EmptyInput) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) {
|
||||
OpTester test("NonMaxSuppression", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("boxes", {1, 6, 4},
|
||||
{0.0f, 0.0f, 1.0f, 1.0f,
|
||||
0.0f, 0.1f, 1.0f, 1.1f,
|
||||
0.0f, -0.1f, 1.0f, 0.9f,
|
||||
0.0f, 10.0f, 1.0f, 11.0f,
|
||||
0.0f, 10.1f, 1.0f, 11.1f,
|
||||
0.0f, 100.0f, 1.0f, 101.0f});
|
||||
test.AddInput<float>("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f});
|
||||
test.AddInput<int32_t>("max_output_boxes_per_class", {}, {0L});
|
||||
test.AddInput<float>("iou_threshold", {}, {0.5f});
|
||||
test.AddInput<float>("score_threshold", {}, {0.4f});
|
||||
test.AddOutput<int32_t>("selected_indices", {0, 3}, {});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue