mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Register CPU kernel for NMS (#1941)
This commit is contained in:
parent
174c142eff
commit
d370aad80c
3 changed files with 33 additions and 3 deletions
|
|
@ -298,7 +298,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int64_t, Slice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, string, Slice);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, NonMaxSuppression);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign);
|
||||
|
|
@ -325,6 +325,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Lo
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, NonMaxSuppression);
|
||||
|
||||
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -608,7 +609,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int64_t, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, string, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, NonMaxSuppression)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, float, RoiAlign)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, double, RoiAlign)>,
|
||||
|
|
@ -635,6 +636,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, NonMaxSuppression)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -17,10 +17,18 @@ limitations under the License.
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
NonMaxSuppression,
|
||||
kOnnxDomain,
|
||||
10, 10,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder(),
|
||||
NonMaxSuppression);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
NonMaxSuppression,
|
||||
kOnnxDomain,
|
||||
10,
|
||||
11,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder(),
|
||||
NonMaxSuppression);
|
||||
|
|
|
|||
|
|
@ -344,5 +344,25 @@ TEST(NonMaxSuppressionOpTest, ZeroMaxOutputPerClass) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(NonMaxSuppressionOpTest, WithIOUThresholdOpset11) {
|
||||
OpTester test("NonMaxSuppression", 11, kOnnxDomain);
|
||||
test.AddInput<float>("boxes", {1, 6, 4},
|
||||
{0.0f, 0.0f, 1.0f, 1.0f,
|
||||
0.0f, 0.1f, 1.0f, 1.1f,
|
||||
0.0f, -0.1f, 1.0f, 0.9f,
|
||||
0.0f, 10.0f, 1.0f, 11.0f,
|
||||
0.0f, 10.1f, 1.0f, 11.1f,
|
||||
0.0f, 100.0f, 1.0f, 101.0f});
|
||||
test.AddInput<float>("scores", {1, 1, 6}, {0.9f, 0.75f, 0.6f, 0.95f, 0.5f, 0.3f});
|
||||
test.AddInput<int64_t>("max_output_boxes_per_class", {}, {3L});
|
||||
test.AddInput<float>("iou_threshold", {}, {0.5f});
|
||||
test.AddInput<float>("score_threshold", {}, {0.0f});
|
||||
test.AddOutput<int64_t>("selected_indices", {3, 3},
|
||||
{0L, 0L, 3L,
|
||||
0L, 0L, 0L,
|
||||
0L, 0L, 5L});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue