mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Fix kernel registry bug (#2137)
This commit is contained in:
parent
2bf1778a5c
commit
13f8b49d58
11 changed files with 176 additions and 86 deletions
|
|
@ -40,9 +40,8 @@ class KernelRegistry {
|
|||
bool IsEmpty() const { return kernel_creator_fn_map_.empty(); }
|
||||
|
||||
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
||||
// This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel
|
||||
const KernelCreateMap& GetKernelCreateMap() const
|
||||
{
|
||||
// This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel
|
||||
const KernelCreateMap& GetKernelCreateMap() const {
|
||||
return kernel_creator_fn_map_;
|
||||
}
|
||||
#endif
|
||||
|
|
@ -64,10 +63,19 @@ class KernelRegistry {
|
|||
// otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored.
|
||||
static bool VerifyKernelDef(const onnxruntime::Node& node,
|
||||
const KernelDef& kernel_def,
|
||||
std::string& error_str,
|
||||
onnxruntime::ProviderType exec_provider = "");
|
||||
std::string& error_str);
|
||||
|
||||
static std::string GetMapKey(const std::string& op_name, const std::string& domain, const std::string& provider) {
|
||||
std::string key(op_name);
|
||||
key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider);
|
||||
return key;
|
||||
}
|
||||
|
||||
static std::string GetMapKey(const KernelDef& kernel_def) {
|
||||
return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider());
|
||||
}
|
||||
// Kernel create function map from op name to kernel creation info.
|
||||
// key is opname+domain_name+provider_name
|
||||
KernelCreateMap kernel_creator_fn_map_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,15 +10,16 @@ namespace automl {
|
|||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSAutoMLDomain, 1, DateTimeTransformer);
|
||||
|
||||
void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) {
|
||||
Status RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
// add more kernels here
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSAutoMLDomain, 1, DateTimeTransformer)>
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
kernel_registry.Register(function_table_entry());
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace automl
|
||||
|
|
|
|||
|
|
@ -8,6 +8,6 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace automl {
|
||||
void RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry);
|
||||
Status RegisterCpuAutoMLKernels(KernelRegistry& kernel_registry);
|
||||
} // namespace automl
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ScaledTanh<float>);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
ThresholdedRelu,
|
||||
1,
|
||||
9,
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ThresholdedRelu<float>);
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Ima
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput);
|
||||
|
|
@ -66,7 +66,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomai
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization);
|
||||
|
||||
void RegisterNchwcKernels(KernelRegistry& kernel_registry) {
|
||||
Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput)>,
|
||||
|
|
@ -77,11 +77,12 @@ void RegisterNchwcKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, GlobalAveragePool)>};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
kernel_registry.Register(function_table_entry());
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
||||
Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>,
|
||||
|
||||
|
|
@ -129,20 +130,20 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, MeanVarianceNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization)>
|
||||
};
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization)>};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
kernel_registry.Register(function_table_entry());
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
}
|
||||
|
||||
// Register the NCHWc kernels if supported by the platform.
|
||||
if (MlasNchwcGetBlockSize() > 1) {
|
||||
RegisterNchwcKernels(kernel_registry);
|
||||
ORT_RETURN_IF_ERROR(RegisterNchwcKernels(kernel_registry));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -8,6 +8,6 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
void RegisterCpuContribKernels(KernelRegistry& kernel_registry);
|
||||
Status RegisterCpuContribKernels(KernelRegistry& kernel_registry);
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ inline bool AreVectorsOverlap(const std::vector<T>& v1, const std::vector<T>& v2
|
|||
}
|
||||
} // namespace
|
||||
|
||||
//TODO: Tell user why it has conflicts
|
||||
bool KernelDef::IsConflict(const KernelDef& other) const {
|
||||
if (op_name_ != other.OpName() || provider_type_ != other.Provider())
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -115,32 +115,7 @@ class TypeBindingResolver {
|
|||
|
||||
bool KernelRegistry::VerifyKernelDef(const onnxruntime::Node& node,
|
||||
const KernelDef& kernel_def,
|
||||
std::string& error_str,
|
||||
onnxruntime::ProviderType exec_provider) {
|
||||
// check if domain matches
|
||||
if (node.Domain() != kernel_def.Domain()) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Op: " << node.OpType()
|
||||
<< " Domain mismatch: "
|
||||
<< " Expected: " << kernel_def.Domain()
|
||||
<< " Actual: " << node.Domain();
|
||||
error_str = ostr.str();
|
||||
return false;
|
||||
}
|
||||
|
||||
// check if execution provider matches
|
||||
const auto& node_provider = node.GetExecutionProviderType();
|
||||
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
|
||||
if (expected_provider != kernel_def.Provider()) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Op: " << node.OpType()
|
||||
<< " Execution provider mismatch."
|
||||
<< " Expected: " << expected_provider
|
||||
<< " Actual: " << kernel_def.Provider();
|
||||
error_str = ostr.str();
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string& error_str) {
|
||||
// check if version matches
|
||||
int kernel_start_version;
|
||||
int kernel_end_version;
|
||||
|
|
@ -215,28 +190,26 @@ Status KernelRegistry::Register(KernelDefBuilder& kernel_builder,
|
|||
}
|
||||
|
||||
Status KernelRegistry::Register(KernelCreateInfo&& create_info) {
|
||||
auto& op_name = create_info.kernel_def->OpName();
|
||||
|
||||
if (!create_info.kernel_def) {
|
||||
return Status(ONNXRUNTIME, FAIL, "kernel def can't be NULL");
|
||||
}
|
||||
std::string key = GetMapKey(*create_info.kernel_def);
|
||||
// Check op version conflicts.
|
||||
auto range = kernel_creator_fn_map_.equal_range(op_name);
|
||||
auto range = kernel_creator_fn_map_.equal_range(key);
|
||||
for (auto i = range.first; i != range.second; ++i) {
|
||||
if (i->second.kernel_def &&
|
||||
i->second.status.IsOK() &&
|
||||
i->second.kernel_def->IsConflict(*create_info.kernel_def)) {
|
||||
auto st = create_info.status =
|
||||
Status(ONNXRUNTIME, FAIL,
|
||||
"Failed to add kernel for " + op_name +
|
||||
": Conflicting with a registered kernel with op versions.");
|
||||
// For invalid entries, we keep them in the map now. Must check for status
|
||||
// when using the entries from the map.
|
||||
kernel_creator_fn_map_.emplace(op_name, std::move(create_info));
|
||||
return st;
|
||||
return create_info.status =
|
||||
Status(ONNXRUNTIME, FAIL,
|
||||
"Failed to add kernel for " + key +
|
||||
": Conflicting with a registered kernel with op versions.");
|
||||
}
|
||||
}
|
||||
|
||||
// Register the kernel.
|
||||
// Ownership of the KernelDef is transferred to the map.
|
||||
kernel_creator_fn_map_.emplace(op_name, std::move(create_info));
|
||||
kernel_creator_fn_map_.emplace(key, std::move(create_info));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -278,7 +251,10 @@ static std::string ToString(const std::vector<std::string>& error_strs) {
|
|||
// otherwise, kernel_def.provider must equal to node.provider. exec_provider is ignored.
|
||||
const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& node,
|
||||
onnxruntime::ProviderType exec_provider) const {
|
||||
auto range = kernel_creator_fn_map_.equal_range(node.OpType());
|
||||
const auto& node_provider = node.GetExecutionProviderType();
|
||||
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
|
||||
|
||||
auto range = kernel_creator_fn_map_.equal_range(GetMapKey(node.OpType(), node.Domain(), expected_provider));
|
||||
std::vector<std::string> error_strs;
|
||||
for (auto i = range.first; i != range.second; ++i) {
|
||||
if (!i->second.status.IsOK()) {
|
||||
|
|
@ -287,13 +263,11 @@ const KernelCreateInfo* KernelRegistry::TryFindKernel(const onnxruntime::Node& n
|
|||
continue;
|
||||
}
|
||||
std::string error_str;
|
||||
if (VerifyKernelDef(node, *i->second.kernel_def, error_str, exec_provider)) {
|
||||
if (VerifyKernelDef(node, *i->second.kernel_def, error_str)) {
|
||||
return &i->second;
|
||||
}
|
||||
error_strs.push_back(error_str);
|
||||
}
|
||||
std::string expected_provider =
|
||||
(node.GetExecutionProviderType().empty() ? exec_provider : node.GetExecutionProviderType());
|
||||
LOGS_DEFAULT(INFO) << node.OpType() << " kernel is not supported in " << expected_provider
|
||||
<< " Encountered following errors: " << ToString(error_strs);
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -88,9 +88,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Xor
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Less);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, double, Less);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Greater);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t, Equal);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int64_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, Equal);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, Equal);
|
||||
|
|
@ -410,7 +410,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Ra
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Unique);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, TopK);
|
||||
|
||||
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10,
|
||||
Clip)>,
|
||||
|
|
@ -509,11 +509,11 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
double, Less)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9,
|
||||
float, Greater)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, bool, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t,
|
||||
Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int64_t,
|
||||
Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, bool, Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int32_t,
|
||||
Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 10, int64_t,
|
||||
Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, bool,
|
||||
Equal)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t,
|
||||
|
|
@ -1047,8 +1047,9 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
kernel_registry.Register(function_table_entry());
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
}
|
||||
return Status::OK();
|
||||
} // namespace onnxruntime
|
||||
|
||||
// Forward declarations of ml op kernels
|
||||
|
|
@ -1104,7 +1105,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, int64_string, LabelEncoder);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, string_int64, LabelEncoder);
|
||||
|
||||
void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, float,
|
||||
ArrayFeatureExtractor)>,
|
||||
|
|
@ -1196,31 +1197,40 @@ void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
kernel_registry.Register(function_table_entry());
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace ml
|
||||
|
||||
static void RegisterCPUKernels(KernelRegistry& kernel_registry) {
|
||||
RegisterOnnxOperatorKernels(kernel_registry);
|
||||
::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry);
|
||||
static Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
|
||||
ORT_RETURN_IF_ERROR(RegisterOnnxOperatorKernels(kernel_registry));
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::ml::RegisterOnnxMLOperatorKernels(kernel_registry));
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry);
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry));
|
||||
#endif
|
||||
#ifdef MICROSOFT_AUTOML
|
||||
::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry);
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::automl::RegisterCpuAutoMLKernels(kernel_registry));
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelRegistry> GetCpuKernelRegistry() {
|
||||
struct KernelRegistryAndStatus{
|
||||
std::shared_ptr<KernelRegistry> kernel_registry = std::make_shared<KernelRegistry>();
|
||||
RegisterCPUKernels(*kernel_registry);
|
||||
return kernel_registry;
|
||||
Status st;
|
||||
};
|
||||
|
||||
KernelRegistryAndStatus GetCpuKernelRegistry() {
|
||||
KernelRegistryAndStatus ret;
|
||||
ret.st = RegisterCPUKernels(*ret.kernel_registry);
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelRegistry> CPUExecutionProvider::GetKernelRegistry() const {
|
||||
static std::shared_ptr<KernelRegistry> kernel_registry = GetCpuKernelRegistry();
|
||||
return kernel_registry;
|
||||
static KernelRegistryAndStatus k = GetCpuKernelRegistry();
|
||||
//throw if the registry failed to initialize
|
||||
ORT_THROW_IF_ERROR(k.st);
|
||||
return k.kernel_registry;
|
||||
}
|
||||
|
||||
std::unique_ptr<IDataTransfer> CPUExecutionProvider::GetDataTransfer() const {
|
||||
|
|
|
|||
|
|
@ -118,9 +118,9 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 9, float, Greater)
|
|||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int32_t, Greater);
|
||||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 9, int64_t, Greater);
|
||||
|
||||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, bool, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int32_t, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 7, int64_t, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, bool, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int32_t, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int64_t, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, bool, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, int32_t, Equal);
|
||||
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 11, int64_t, Equal);
|
||||
|
|
|
|||
94
onnxruntime/test/framework/kernel_registry_test.cc
Normal file
94
onnxruntime/test/framework/kernel_registry_test.cc
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <core/framework/kernel_registry.h>
|
||||
#include <core/framework/op_kernel.h>
|
||||
|
||||
using namespace onnxruntime;
|
||||
static Status RegKernels(KernelRegistry& r, std::vector<std::unique_ptr<KernelDef> >& function_table, const KernelCreateFn& kernel_creator) {
|
||||
for (auto& function_table_entry : function_table) {
|
||||
ORT_RETURN_IF_ERROR(r.Register(KernelCreateInfo(std::move(function_table_entry), kernel_creator)));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class FakeKernel final : public OpKernel {
|
||||
public:
|
||||
FakeKernel(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
Status Compute(OpKernelContext*) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
OpKernel* CreateFakeKernel(const OpKernelInfo& info) {
|
||||
return new FakeKernel(info);
|
||||
}
|
||||
|
||||
TEST(KernelRegistryTests, simple) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
|
||||
Status st;
|
||||
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
|
||||
}
|
||||
|
||||
TEST(KernelRegistryTests, dup_simple) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
|
||||
}
|
||||
|
||||
//duplicated registration. One in default("") domain, another in "ai.onnx" domain
|
||||
TEST(KernelRegistryTests, dup_simple2) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("ai.onnx").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
|
||||
}
|
||||
|
||||
//One in default("") domain, another in ms domain. Should be ok
|
||||
TEST(KernelRegistryTests, one_op_name_in_two_domains) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain(kMSDomain).SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
|
||||
}
|
||||
|
||||
//One op two versions
|
||||
TEST(KernelRegistryTests, two_versions) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1, 5).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_TRUE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK()) << st.ErrorMessage();
|
||||
}
|
||||
|
||||
//One op two versions
|
||||
TEST(KernelRegistryTests, two_versions2) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1, 6).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK());
|
||||
}
|
||||
|
||||
//One op two versions
|
||||
TEST(KernelRegistryTests, two_versions3) {
|
||||
KernelRegistry r;
|
||||
std::vector<std::unique_ptr<KernelDef> > function_table;
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(6).Provider(kCpuExecutionProvider).Build());
|
||||
function_table.emplace_back(KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).SetName("Elu").SetDomain("").SinceVersion(1).Provider(kCpuExecutionProvider).Build());
|
||||
Status st;
|
||||
ASSERT_FALSE((st = RegKernels(r, function_table, CreateFakeKernel)).IsOK());
|
||||
}
|
||||
Loading…
Reference in a new issue