From 13f8b49d58bd58c3e12c42779bd2dcbd0d8760e7 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 17 Oct 2019 23:10:54 -0700 Subject: [PATCH] Fix kernel registry bug (#2137) --- .../core/framework/kernel_registry.h | 18 +++- onnxruntime/automl_ops/cpu_automl_kernels.cc | 5 +- onnxruntime/automl_ops/cpu_automl_kernels.h | 2 +- onnxruntime/contrib_ops/cpu/activations.cc | 3 +- .../contrib_ops/cpu_contrib_kernels.cc | 19 ++-- onnxruntime/contrib_ops/cpu_contrib_kernels.h | 2 +- .../core/framework/kernel_def_builder.cc | 1 + onnxruntime/core/framework/kernel_registry.cc | 58 ++++-------- .../providers/cpu/cpu_execution_provider.cc | 54 ++++++----- .../providers/cpu/math/element_wise_ops.cc | 6 +- .../test/framework/kernel_registry_test.cc | 94 +++++++++++++++++++ 11 files changed, 176 insertions(+), 86 deletions(-) create mode 100644 onnxruntime/test/framework/kernel_registry_test.cc diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 95d9b1d415..aac9c4db5f 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -40,9 +40,8 @@ class KernelRegistry { bool IsEmpty() const { return kernel_creator_fn_map_.empty(); } #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA -// This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel - const KernelCreateMap& GetKernelCreateMap() const - { + // This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel + const KernelCreateMap& GetKernelCreateMap() const { return kernel_creator_fn_map_; } #endif @@ -64,10 +63,19 @@ class KernelRegistry { // otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored. static bool VerifyKernelDef(const onnxruntime::Node& node, const KernelDef& kernel_def, - std::string& error_str, - onnxruntime::ProviderType exec_provider = ""); + std::string& error_str); + static std::string GetMapKey(const std::string& op_name, const std::string& domain, const std::string& provider) { + std::string key(op_name); + key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider); + return key; + } + + static std::string GetMapKey(const KernelDef& kernel_def) { + return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider()); + } // Kernel create function map from op name to kernel creation info. + // key is opname+domain_name+provider_name KernelCreateMap kernel_creator_fn_map_; }; } // namespace onnxruntime diff --git a/onnxruntime/automl_ops/cpu_automl_kernels.cc b/onnxruntime/automl_ops/cpu_automl_kernels.cc index 23d5e2ad72..e2a0e41ab9 100644 --- a/onnxruntime/automl_ops/cpu_automl_kernels.cc +++ b/onnxruntime/automl_ops/cpu_automl_kernels.cc @@ -10,15 +10,16 @@ namespace automl { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSAutoMLDomain, 1, DateTimeTransformer); -void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) { +Status RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { // add more kernels here BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } + return Status::OK(); } } // namespace automl diff --git a/onnxruntime/automl_ops/cpu_automl_kernels.h b/onnxruntime/automl_ops/cpu_automl_kernels.h index f14a8983d5..1f08cdf127 100644 --- a/onnxruntime/automl_ops/cpu_automl_kernels.h +++ b/onnxruntime/automl_ops/cpu_automl_kernels.h @@ -8,6 +8,6 @@ namespace onnxruntime { namespace automl { -void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry); +Status RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry); } // namespace automl } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc index 28f96dbe61..ff5d781e75 100644 --- a/onnxruntime/contrib_ops/cpu/activations.cc +++ b/onnxruntime/contrib_ops/cpu/activations.cc @@ -19,9 +19,10 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ScaledTanh); -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ThresholdedRelu, 1, + 9, KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ThresholdedRelu); diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc index 02bc05b78a..fc51291cac 100644 --- a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc @@ -54,7 +54,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Ima class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, ThresholdedRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput); @@ -66,7 +66,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization); -void RegisterNchwcKernels(KernelRegistry& kernel_registry) { +Status RegisterNchwcKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -77,11 +77,12 @@ void RegisterNchwcKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } + return Status::OK(); } -void RegisterCpuContribKernels(KernelRegistry& kernel_registry) { +Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, @@ -129,20 +130,20 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo - }; + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } // Register the NCHWc kernels if supported by the platform. if (MlasNchwcGetBlockSize() > 1) { - RegisterNchwcKernels(kernel_registry); + ORT_RETURN_IF_ERROR(RegisterNchwcKernels(kernel_registry)); } + return Status::OK(); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.h b/onnxruntime/contrib_ops/cpu_contrib_kernels.h index aaa38161a0..1158af019e 100644 --- a/onnxruntime/contrib_ops/cpu_contrib_kernels.h +++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.h @@ -8,6 +8,6 @@ namespace onnxruntime { namespace contrib { -void RegisterCpuContribKernels(KernelRegistry& kernel_registry); +Status RegisterCpuContribKernels(KernelRegistry& kernel_registry); } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_def_builder.cc b/onnxruntime/core/framework/kernel_def_builder.cc index 724d43c3a0..0d475eab46 100644 --- a/onnxruntime/core/framework/kernel_def_builder.cc +++ b/onnxruntime/core/framework/kernel_def_builder.cc @@ -24,6 +24,7 @@ inline bool AreVectorsOverlap(const std::vector& v1, const std::vector& v2 } } // namespace +//TODO: Tell user why it has conflicts bool KernelDef::IsConflict(const KernelDef& other) const { if (op_name_ != other.OpName() || provider_type_ != other.Provider()) return false; diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index feb26f06a0..cb6e8bc83f 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -115,32 +115,7 @@ class TypeBindingResolver { bool KernelRegistry::VerifyKernelDef(const onnxruntime::Node& node, const KernelDef& kernel_def, - std::string& error_str, - onnxruntime::ProviderType exec_provider) { - // check if domain matches - if (node.Domain() != kernel_def.Domain()) { - std::ostringstream ostr; - ostr << "Op: " << node.OpType() - << " Domain mismatch: " - << " Expected: " << kernel_def.Domain() - << " Actual: " << node.Domain(); - error_str = ostr.str(); - return false; - } - - // check if execution provider matches - const auto& node_provider = node.GetExecutionProviderType(); - const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider); - if (expected_provider != kernel_def.Provider()) { - std::ostringstream ostr; - ostr << "Op: " << node.OpType() - << " Execution provider mismatch." - << " Expected: " << expected_provider - << " Actual: " << kernel_def.Provider(); - error_str = ostr.str(); - return false; - } - + std::string& error_str) { // check if version matches int kernel_start_version; int kernel_end_version; @@ -215,28 +190,26 @@ Status KernelRegistry::Register(KernelDefBuilder& kernel_builder, } Status KernelRegistry::Register(KernelCreateInfo&& create_info) { - auto& op_name = create_info.kernel_def->OpName(); - + if (!create_info.kernel_def) { + return Status(ONNXRUNTIME, FAIL, "kernel def can't be NULL"); + } + std::string key = GetMapKey(*create_info.kernel_def); // Check op version conflicts. - auto range = kernel_creator_fn_map_.equal_range(op_name); + auto range = kernel_creator_fn_map_.equal_range(key); for (auto i = range.first; i != range.second; ++i) { if (i->second.kernel_def && i->second.status.IsOK() && i->second.kernel_def->IsConflict(*create_info.kernel_def)) { - auto st = create_info.status = - Status(ONNXRUNTIME, FAIL, - "Failed to add kernel for " + op_name + - ": Conflicting with a registered kernel with op versions."); - // For invalid entries, we keep them in the map now. Must check for status - // when using the entries from the map. - kernel_creator_fn_map_.emplace(op_name, std::move(create_info)); - return st; + return create_info.status = + Status(ONNXRUNTIME, FAIL, + "Failed to add kernel for " + key + + ": Conflicting with a registered kernel with op versions."); } } // Register the kernel. // Ownership of the KernelDef is transferred to the map. - kernel_creator_fn_map_.emplace(op_name, std::move(create_info)); + kernel_creator_fn_map_.emplace(key, std::move(create_info)); return Status::OK(); } @@ -278,7 +251,10 @@ static std::string ToString(const std::vector& error_strs) { // otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored. const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& node, onnxruntime::ProviderType exec_provider) const { - auto range = kernel_creator_fn_map_.equal_range(node.OpType()); + const auto& node_provider = node.GetExecutionProviderType(); + const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider); + + auto range = kernel_creator_fn_map_.equal_range(GetMapKey(node.OpType(), node.Domain(), expected_provider)); std::vector error_strs; for (auto i = range.first; i != range.second; ++i) { if (!i->second.status.IsOK()) { @@ -287,13 +263,11 @@ const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& n continue; } std::string error_str; - if (VerifyKernelDef(node, *i->second.kernel_def, error_str, exec_provider)) { + if (VerifyKernelDef(node, *i->second.kernel_def, error_str)) { return &i->second; } error_strs.push_back(error_str); } - std::string expected_provider = - (node.GetExecutionProviderType().empty() ? exec_provider : node.GetExecutionProviderType()); LOGS_DEFAULT(INFO) << node.OpType() << " kernel is not supported in " << expected_provider << " Encountered following errors: " << ToString(error_strs); return nullptr; diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 5ecf2c8000..9c0f25231d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -88,9 +88,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Xor class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Greater); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int64_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, Equal); @@ -410,7 +410,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Ra class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unique); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, TopK); -void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { +Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, @@ -509,11 +509,11 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, Less)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1196,31 +1197,40 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { }; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } + return Status::OK(); } } // namespace ml -static void RegisterCPUKernels(KernelRegistry& kernel_registry) { - RegisterOnnxOperatorKernels(kernel_registry); - ::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry); +static Status RegisterCPUKernels(KernelRegistry& kernel_registry) { + ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry)); + ORT_RETURN_IF_ERROR(::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry)); #ifndef DISABLE_CONTRIB_OPS - ::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry); + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry)); #endif #ifdef MICROSOFT_AUTOML - ::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry); + ORT_RETURN_IF_ERROR(::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry)); #endif + return Status::OK(); } -std::shared_ptr GetCpuKernelRegistry() { +struct KernelRegistryAndStatus{ std::shared_ptr kernel_registry = std::make_shared(); - RegisterCPUKernels(*kernel_registry); - return kernel_registry; + Status st; +}; + +KernelRegistryAndStatus GetCpuKernelRegistry() { + KernelRegistryAndStatus ret; + ret.st = RegisterCPUKernels(*ret.kernel_registry); + return ret; } std::shared_ptr CPUExecutionProvider::GetKernelRegistry() const { - static std::shared_ptr kernel_registry = GetCpuKernelRegistry(); - return kernel_registry; + static KernelRegistryAndStatus k = GetCpuKernelRegistry(); + //throw if the registry failed to initialize + ORT_THROW_IF_ERROR(k.st); + return k.kernel_registry; } std::unique_ptr CPUExecutionProvider::GetDataTransfer() const { diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 949812bd1f..93c37c53ec 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -118,9 +118,9 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater) REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int32_t, Greater); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int64_t, Greater); -REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, bool, Equal); -REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int32_t, Equal); -REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int64_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, bool, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int32_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int64_t, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, bool, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, int32_t, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, int64_t, Equal); diff --git a/onnxruntime/test/framework/kernel_registry_test.cc b/onnxruntime/test/framework/kernel_registry_test.cc new file mode 100644 index 0000000000..c186a0b035 --- /dev/null +++ b/onnxruntime/test/framework/kernel_registry_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +using namespace onnxruntime; +static Status RegKernels(KernelRegistry& r, std::vector >& function_table, const KernelCreateFn& kernel_creator) { + for (auto& function_table_entry : function_table) { + ORT_RETURN_IF_ERROR(r.Register(KernelCreateInfo(std::move(function_table_entry), kernel_creator))); + } + return Status::OK(); +} + +class FakeKernel final : public OpKernel { + public: + FakeKernel(const OpKernelInfo& info) : OpKernel(info) {} + Status Compute(OpKernelContext*) const override { + return Status::OK(); + } +}; + +OpKernel* CreateFakeKernel(const OpKernelInfo& info) { + return new FakeKernel(info); +} + +TEST(KernelRegistryTests, simple) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + + Status st; + ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage(); +} + +TEST(KernelRegistryTests, dup_simple) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + Status st; + ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage(); +} + +//duplicated registration. One in default("") domain, another in "ai.onnx" domain +TEST(KernelRegistryTests, dup_simple2) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("ai.onnx").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + Status st; + ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage(); +} + +//One in default("") domain, another in ms domain. Should be ok +TEST(KernelRegistryTests, one_op_name_in_two_domains) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain(kMSDomain).SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + Status st; + ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage(); +} + +//One op two versions +TEST(KernelRegistryTests, two_versions) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(1, 5).Provider(kCpuExecutionProvider).Build()); + Status st; + ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage(); +} + +//One op two versions +TEST(KernelRegistryTests, two_versions2) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(1, 6).Provider(kCpuExecutionProvider).Build()); + Status st; + ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()); +} + +//One op two versions +TEST(KernelRegistryTests, two_versions3) { + KernelRegistry r; + std::vector > function_table; + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build()); + function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()).SetName("Elu").SetDomain("").SinceVersion(1).Provider(kCpuExecutionProvider).Build()); + Status st; + ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()); +} \ No newline at end of file