diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc new file mode 100644 index 0000000000..2a8172ec4f --- /dev/null +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/contrib_kernels.h" +#include "core/graph/constants.h" + +namespace onnxruntime { +namespace contrib { + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); + +void RegisterContribKernels(std::function fn) { + fn(BuildKernel()); + + // add more kernels here + + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); + fn(BuildKernel()); +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/contrib_kernels.h b/onnxruntime/contrib_ops/contrib_kernels.h new file mode 100644 index 0000000000..a8e6f44157 --- /dev/null +++ b/onnxruntime/contrib_ops/contrib_kernels.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_registry.h" + +namespace onnxruntime { +namespace contrib { +void RegisterContribKernels(std::function create_fn); +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h index 6705ce744c..fdbb65a5de 100644 --- a/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h +++ b/onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h @@ -53,7 +53,7 @@ class DeepCpuAttnLstmOp final : public OpKernel { } } - ONNXRUNTIME_ENFORCE(activation_func_names.size() == num_directions_ * 3); + ONNXRUNTIME_ENFORCE(static_cast(activation_func_names.size()) == num_directions_ * 3); activation_funcs_ = ActivationFuncs(activation_func_names, activation_func_alphas, diff --git a/onnxruntime/core/framework/environment.cc b/onnxruntime/core/framework/environment.cc index 48bc2fcff1..58c51b2f06 100644 --- a/onnxruntime/core/framework/environment.cc +++ b/onnxruntime/core/framework/environment.cc @@ -4,9 +4,8 @@ #include "core/framework/environment.h" #include "core/framework/allocatormgr.h" #include "core/graph/constants.h" +#include "core/graph/contrib_ops/contrib_defs.h" #include "core/graph/op.h" -#include "onnx/defs/schema.h" -#include "contrib_ops/contrib_ops.h" namespace onnxruntime { using namespace ::onnxruntime::common; diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attn_lstm_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.cc similarity index 100% rename from onnxruntime/contrib_ops/cpu/attnlstm/attn_lstm_schema_defs.cc rename to onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.cc diff --git a/onnxruntime/contrib_ops/cpu/attnlstm/attn_lstm_schema_defs.h b/onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.h similarity index 100% rename from onnxruntime/contrib_ops/cpu/attnlstm/attn_lstm_schema_defs.h rename to onnxruntime/core/graph/contrib_ops/attn_lstm_schema_defs.h diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc similarity index 92% rename from onnxruntime/contrib_ops/contrib_ops.cc rename to onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 7c8f8df3c7..c60e93e66b 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1,14 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "contrib_ops/contrib_ops.h" - #include "core/graph/constants.h" +#include "core/graph/contrib_ops/attn_lstm_schema_defs.h" +#include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/contrib_ops/range_schema_defs.h" #include "core/graph/op.h" -#include "onnx/defs/schema.h" - -#include "./cpu/attnlstm/attn_lstm_schema_defs.h" -#include "./cpu/range_schema_defs.h" namespace onnxruntime { namespace contrib { @@ -502,32 +499,5 @@ The bounding box coordinates corresponding to the selected indices can then be o }) .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC"); } - -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); - -void RegisterContribKernels(std::function fn) { - fn(BuildKernel()); - - // add more kernels here - - fn(BuildKernel()); - fn(BuildKernel()); - fn(BuildKernel()); - fn(BuildKernel()); - fn(BuildKernel()); - fn(BuildKernel()); - fn(BuildKernel()); - fn(BuildKernel()); - fn(BuildKernel()); -} } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/contrib_ops.h b/onnxruntime/core/graph/contrib_ops/contrib_defs.h similarity index 85% rename from onnxruntime/contrib_ops/contrib_ops.h rename to onnxruntime/core/graph/contrib_ops/contrib_defs.h index 2acd2ecfaf..10da0479d7 100644 --- a/onnxruntime/contrib_ops/contrib_ops.h +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.h @@ -3,8 +3,7 @@ #pragma once -#include "core/framework/op_kernel.h" -#include "core/framework/kernel_registry.h" +#include "core/graph/onnx_protobuf.h" namespace onnxruntime { namespace contrib { @@ -21,12 +20,11 @@ namespace contrib { ONNX_CONTRIB_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(__COUNTER__, name, schema_func) #define ONNX_CONTRIB_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(Counter, name, schema_func) \ ONNX_CONTRIB_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) -#define ONNX_CONTRIB_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) \ - static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ - op_schema_register_once##name##Counter) ONNX_UNUSED = \ +#define ONNX_CONTRIB_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) \ + static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ + op_schema_register_once##name##Counter) ONNX_UNUSED = \ schema_func(ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__)) void RegisterContribSchemas(); -void RegisterContribKernels(std::function create_fn); -} // namespace contrib +} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/range_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/range_schema_defs.cc similarity index 100% rename from onnxruntime/contrib_ops/cpu/range_schema_defs.cc rename to onnxruntime/core/graph/contrib_ops/range_schema_defs.cc diff --git a/onnxruntime/contrib_ops/cpu/range_schema_defs.h b/onnxruntime/core/graph/contrib_ops/range_schema_defs.h similarity index 100% rename from onnxruntime/contrib_ops/cpu/range_schema_defs.h rename to onnxruntime/core/graph/contrib_ops/range_schema_defs.h diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index cdcb09e523..4cc36976ac 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -4,7 +4,7 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" -#include "contrib_ops/contrib_ops.h" +#include "contrib_ops/contrib_kernels.h" #include "core/framework/computation_capacity.h" namespace onnxruntime {