Fix kernel registry bug (#2137)

This commit is contained in:
Changming Sun 2019-10-17 23:10:54 -07:00 committed by GitHub
parent 2bf1778a5c
commit 13f8b49d58
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 176 additions and 86 deletions

View file

@ -40,9 +40,8 @@ class KernelRegistry {
bool IsEmpty() const { return kernel_creator_fn_map_.empty(); }
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
// This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel
const KernelCreateMap& GetKernelCreateMap() const
{
// This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel
const KernelCreateMap& GetKernelCreateMap() const {
return kernel_creator_fn_map_;
}
#endif
@ -64,10 +63,19 @@ class KernelRegistry {
// otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored.
static bool VerifyKernelDef(const onnxruntime::Node& node,
const KernelDef& kernel_def,
std::string& error_str,
onnxruntime::ProviderType exec_provider = "");
std::string& error_str);
static std::string GetMapKey(const std::string& op_name, const std::string& domain, const std::string& provider) {
std::string key(op_name);
key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider);
return key;
}
static std::string GetMapKey(const KernelDef& kernel_def) {
return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider());
}
// Kernel create function map from op name to kernel creation info.
// key is opname+domain_name+provider_name
KernelCreateMap kernel_creator_fn_map_;
};
} // namespace onnxruntime

View file

@ -10,15 +10,16 @@ namespace automl {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSAutoMLDomain, 1, DateTimeTransformer);
void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) {
Status RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
// add more kernels here
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSAutoMLDomain, 1, DateTimeTransformer)>
};
for (auto& function_table_entry : function_table) {
kernel_registry.Register(function_table_entry());
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
}
return Status::OK();
}
} // namespace automl

View file

@ -8,6 +8,6 @@
namespace onnxruntime {
namespace automl {
void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry);
Status RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry);
} // namespace automl
} // namespace onnxruntime

View file

@ -19,9 +19,10 @@ ONNX_CPU_OPERATOR_KERNEL(
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ScaledTanh<float>);
ONNX_CPU_OPERATOR_KERNEL(
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
ThresholdedRelu,
1,
9,
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ThresholdedRelu<float>);

View file

@ -54,7 +54,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Ima
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, ThresholdedRelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput);
@ -66,7 +66,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization);
void RegisterNchwcKernels(KernelRegistry& kernel_registry) {
Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput)>,
@ -77,11 +77,12 @@ void RegisterNchwcKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, GlobalAveragePool)>};
for (auto& function_table_entry : function_table) {
kernel_registry.Register(function_table_entry());
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
}
return Status::OK();
}
void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>,
@ -129,20 +130,20 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization)>
};
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization)>};
for (auto& function_table_entry : function_table) {
kernel_registry.Register(function_table_entry());
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
}
// Register the NCHWc kernels if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
RegisterNchwcKernels(kernel_registry);
ORT_RETURN_IF_ERROR(RegisterNchwcKernels(kernel_registry));
}
return Status::OK();
}
} // namespace contrib

View file

@ -8,6 +8,6 @@
namespace onnxruntime {
namespace contrib {
void RegisterCpuContribKernels(KernelRegistry& kernel_registry);
Status RegisterCpuContribKernels(KernelRegistry& kernel_registry);
} // namespace contrib
} // namespace onnxruntime

View file

@ -24,6 +24,7 @@ inline bool AreVectorsOverlap(const std::vector<T>& v1, const std::vector<T>& v2
}
} // namespace
//TODO: Tell user why it has conflicts
bool KernelDef::IsConflict(const KernelDef& other) const {
if (op_name_ != other.OpName() || provider_type_ != other.Provider())
return false;

View file

@ -115,32 +115,7 @@ class TypeBindingResolver {
bool KernelRegistry::VerifyKernelDef(const onnxruntime::Node& node,
const KernelDef& kernel_def,
std::string& error_str,
onnxruntime::ProviderType exec_provider) {
// check if domain matches
if (node.Domain() != kernel_def.Domain()) {
std::ostringstream ostr;
ostr << "Op: " << node.OpType()
<< " Domain mismatch: "
<< " Expected: " << kernel_def.Domain()
<< " Actual: " << node.Domain();
error_str = ostr.str();
return false;
}
// check if execution provider matches
const auto& node_provider = node.GetExecutionProviderType();
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
if (expected_provider != kernel_def.Provider()) {
std::ostringstream ostr;
ostr << "Op: " << node.OpType()
<< " Execution provider mismatch."
<< " Expected: " << expected_provider
<< " Actual: " << kernel_def.Provider();
error_str = ostr.str();
return false;
}
std::string& error_str) {
// check if version matches
int kernel_start_version;
int kernel_end_version;
@ -215,28 +190,26 @@ Status KernelRegistry::Register(KernelDefBuilder& kernel_builder,
}
Status KernelRegistry::Register(KernelCreateInfo&& create_info) {
auto& op_name = create_info.kernel_def->OpName();
if (!create_info.kernel_def) {
return Status(ONNXRUNTIME, FAIL, "kernel def can't be NULL");
}
std::string key = GetMapKey(*create_info.kernel_def);
// Check op version conflicts.
auto range = kernel_creator_fn_map_.equal_range(op_name);
auto range = kernel_creator_fn_map_.equal_range(key);
for (auto i = range.first; i != range.second; ++i) {
if (i->second.kernel_def &&
i->second.status.IsOK() &&
i->second.kernel_def->IsConflict(*create_info.kernel_def)) {
auto st = create_info.status =
Status(ONNXRUNTIME, FAIL,
"Failed to add kernel for " + op_name +
": Conflicting with a registered kernel with op versions.");
// For invalid entries, we keep them in the map now. Must check for status
// when using the entries from the map.
kernel_creator_fn_map_.emplace(op_name, std::move(create_info));
return st;
return create_info.status =
Status(ONNXRUNTIME, FAIL,
"Failed to add kernel for " + key +
": Conflicting with a registered kernel with op versions.");
}
}
// Register the kernel.
// Ownership of the KernelDef is transferred to the map.
kernel_creator_fn_map_.emplace(op_name, std::move(create_info));
kernel_creator_fn_map_.emplace(key, std::move(create_info));
return Status::OK();
}
@ -278,7 +251,10 @@ static std::string ToString(const std::vector<std::string>& error_strs) {
// otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored.
const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& node,
onnxruntime::ProviderType exec_provider) const {
auto range = kernel_creator_fn_map_.equal_range(node.OpType());
const auto& node_provider = node.GetExecutionProviderType();
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
auto range = kernel_creator_fn_map_.equal_range(GetMapKey(node.OpType(), node.Domain(), expected_provider));
std::vector<std::string> error_strs;
for (auto i = range.first; i != range.second; ++i) {
if (!i->second.status.IsOK()) {
@ -287,13 +263,11 @@ const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& n
continue;
}
std::string error_str;
if (VerifyKernelDef(node, *i->second.kernel_def, error_str, exec_provider)) {
if (VerifyKernelDef(node, *i->second.kernel_def, error_str)) {
return &i->second;
}
error_strs.push_back(error_str);
}
std::string expected_provider =
(node.GetExecutionProviderType().empty() ? exec_provider : node.GetExecutionProviderType());
LOGS_DEFAULT(INFO) << node.OpType() << " kernel is not supported in " << expected_provider
<< " Encountered following errors: " << ToString(error_strs);
return nullptr;

View file

@ -88,9 +88,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Xor
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Less);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, Less);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t, Equal);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int64_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, Equal);
@ -410,7 +410,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Ra
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unique);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, TopK);
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10,
Clip)>,
@ -509,11 +509,11 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
double, Less)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9,
float, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int64_t,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool,
Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t,
@ -1047,8 +1047,9 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
};
for (auto& function_table_entry : function_table) {
kernel_registry.Register(function_table_entry());
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
}
return Status::OK();
} // namespace onnxruntime
// Forward declarations of ml op kernels
@ -1104,7 +1105,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder);
void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, float,
ArrayFeatureExtractor)>,
@ -1196,31 +1197,40 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
};
for (auto& function_table_entry : function_table) {
kernel_registry.Register(function_table_entry());
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
}
return Status::OK();
}
} // namespace ml
static void RegisterCPUKernels(KernelRegistry& kernel_registry) {
RegisterOnnxOperatorKernels(kernel_registry);
::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry);
static Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry));
ORT_RETURN_IF_ERROR(::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry));
#ifndef DISABLE_CONTRIB_OPS
::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry);
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry));
#endif
#ifdef MICROSOFT_AUTOML
::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry);
ORT_RETURN_IF_ERROR(::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry));
#endif
return Status::OK();
}
std::shared_ptr<KernelRegistry> GetCpuKernelRegistry() {
struct KernelRegistryAndStatus{
std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>();
RegisterCPUKernels(*kernel_registry);
return kernel_registry;
Status st;
};
KernelRegistryAndStatus GetCpuKernelRegistry() {
KernelRegistryAndStatus ret;
ret.st = RegisterCPUKernels(*ret.kernel_registry);
return ret;
}
std::shared_ptr<KernelRegistry> CPUExecutionProvider::GetKernelRegistry() const {
static std::shared_ptr<KernelRegistry> kernel_registry = GetCpuKernelRegistry();
return kernel_registry;
static KernelRegistryAndStatus k = GetCpuKernelRegistry();
//throw if the registry failed to initialize
ORT_THROW_IF_ERROR(k.st);
return k.kernel_registry;
}
std::unique_ptr<IDataTransfer> CPUExecutionProvider::GetDataTransfer() const {

View file

@ -118,9 +118,9 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater)
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int32_t, Greater);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int64_t, Greater);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, bool, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int32_t, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int64_t, Equal);
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, bool, Equal);
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int32_t, Equal);
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int64_t, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, bool, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, int32_t, Equal);
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, int64_t, Equal);

View file

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <gtest/gtest.h>
#include <core/framework/kernel_registry.h>
#include <core/framework/op_kernel.h>
using namespace onnxruntime;
static Status RegKernels(KernelRegistry& r, std::vector<std::unique_ptr<KernelDef> >& function_table, const KernelCreateFn& kernel_creator) {
for (auto& function_table_entry : function_table) {
ORT_RETURN_IF_ERROR(r.Register(KernelCreateInfo(std::move(function_table_entry), kernel_creator)));
}
return Status::OK();
}
class FakeKernel final : public OpKernel {
public:
FakeKernel(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext*) const override {
return Status::OK();
}
};
OpKernel* CreateFakeKernel(const OpKernelInfo& info) {
return new FakeKernel(info);
}
TEST(KernelRegistryTests, simple) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
}
TEST(KernelRegistryTests, dup_simple) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
}
//duplicated registration. One in default("") domain, another in "ai.onnx" domain
TEST(KernelRegistryTests, dup_simple2) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("ai.onnx").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
}
//One in default("") domain, another in ms domain. Should be ok
TEST(KernelRegistryTests, one_op_name_in_two_domains) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain(kMSDomain).SinceVersion(6).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
}
//One op two versions
TEST(KernelRegistryTests, two_versions) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1, 5).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
}
//One op two versions
TEST(KernelRegistryTests, two_versions2) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1, 6).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK());
}
//One op two versions
TEST(KernelRegistryTests, two_versions3) {
KernelRegistry r;
std::vector<std::unique_ptr<KernelDef> > function_table;
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1).Provider(kCpuExecutionProvider).Build());
Status st;
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK());
}