diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index dc11627ca8..53edf49d03 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -298,7 +298,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, string, Slice); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, NonMaxSuppression); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign); @@ -325,6 +325,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Lo class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -608,7 +609,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -635,6 +636,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc index c1a376026a..87d21ab321 100644 --- a/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cpu/object_detection/non_max_suppression.cc @@ -17,10 +17,18 @@ limitations under the License. namespace onnxruntime { +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + NonMaxSuppression, + kOnnxDomain, + 10, 10, + kCpuExecutionProvider, + KernelDefBuilder(), + NonMaxSuppression); + ONNX_OPERATOR_KERNEL_EX( NonMaxSuppression, kOnnxDomain, - 10, + 11, kCpuExecutionProvider, KernelDefBuilder(), NonMaxSuppression); diff --git a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc index 45f537bc89..1a796d9da0 100644 --- a/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc +++ b/onnxruntime/test/providers/cpu/object_detection/non_max_suppression_test.cc @@ -344,5 +344,25 @@ TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) { test.Run(); } +TEST(NonMaxSuppressionOpTest, WithIOUThresholdOpset11) { + OpTester test("NonMaxSuppression", 11, kOnnxDomain); + test.AddInput("boxes", {1, 6, 4}, + {0.0f, 0.0f, 1.0f, 1.0f, + 0.0f, 0.1f, 1.0f, 1.1f, + 0.0f, -0.1f, 1.0f, 0.9f, + 0.0f, 10.0f, 1.0f, 11.0f, + 0.0f, 10.1f, 1.0f, 11.1f, + 0.0f, 100.0f, 1.0f, 101.0f}); + test.AddInput("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f}); + test.AddInput("max_output_boxes_per_class", {}, {3L}); + test.AddInput("iou_threshold", {}, {0.5f}); + test.AddInput("score_threshold", {}, {0.0f}); + test.AddOutput("selected_indices", {3, 3}, + {0L, 0L, 3L, + 0L, 0L, 0L, + 0L, 0L, 5L}); + test.Run(); +} + } // namespace test } // namespace onnxruntime