Register CPU kernel for NMS (#1941)

This commit is contained in:
Hariharan Seshadri 2019-09-27 10:25:01 -07:00 committed by GitHub
parent 174c142eff
commit d370aad80c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 3 deletions

View file

@ -298,7 +298,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, string, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, NonMaxSuppression);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign);
@ -325,6 +325,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Lo
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, NonMaxSuppression);
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -608,7 +609,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, string, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, NonMaxSuppression)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign)>,
@ -635,6 +636,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, NonMaxSuppression)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -17,10 +17,18 @@ limitations under the License.
namespace onnxruntime {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
NonMaxSuppression,
kOnnxDomain,
10, 10,
kCpuExecutionProvider,
KernelDefBuilder(),
NonMaxSuppression);
ONNX_OPERATOR_KERNEL_EX(
NonMaxSuppression,
kOnnxDomain,
10,
11,
kCpuExecutionProvider,
KernelDefBuilder(),
NonMaxSuppression);

View file

@ -344,5 +344,25 @@ TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) {
test.Run();
}
TEST(NonMaxSuppressionOpTest, WithIOUThresholdOpset11) {
OpTester test("NonMaxSuppression", 11, kOnnxDomain);
test.AddInput<float>("boxes", {1, 6, 4},
{0.0f, 0.0f, 1.0f, 1.0f,
0.0f, 0.1f, 1.0f, 1.1f,
0.0f, -0.1f, 1.0f, 0.9f,
0.0f, 10.0f, 1.0f, 11.0f,
0.0f, 10.1f, 1.0f, 11.1f,
0.0f, 100.0f, 1.0f, 101.0f});
test.AddInput<float>("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f});
test.AddInput<int64_t>("max_output_boxes_per_class", {}, {3L});
test.AddInput<float>("iou_threshold", {}, {0.5f});
test.AddInput<float>("score_threshold", {}, {0.0f});
test.AddOutput<int64_t>("selected_indices", {3, 3},
{0L, 0L, 3L,
0L, 0L, 0L,
0L, 0L, 5L});
test.Run();
}
} // namespace test
} // namespace onnxruntime