From 3568f8d186df180a2b942a42254dffd761ab2ea6 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Thu, 2 Apr 2020 15:38:51 -0700 Subject: [PATCH] Allow a custom op with the same name to be registered for several providers. (#3400) --- onnxruntime/core/graph/schema_registry.cc | 4 +- onnxruntime/test/shared_lib/test_inference.cc | 42 ++++++++++++++++--- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/graph/schema_registry.cc b/onnxruntime/core/graph/schema_registry.cc index 8127af7877..acd425001c 100644 --- a/onnxruntime/core/graph/schema_registry.cc +++ b/onnxruntime/core/graph/schema_registry.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/graph/schema_registry.h" +#include "core/common/logging/logging.h" namespace onnxruntime { // Add customized domain to min/max version. @@ -69,7 +70,8 @@ common::Status OnnxRuntimeOpSchemaRegistry::RegisterOpSchemaInternal(ONNX_NAMESP << op_schema.line() << ", but it is already registered from file " << schema.file() << " line " << schema.line() << std::endl; - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostream.str()); + LOGS_DEFAULT(WARNING) << ostream.str(); + return common::Status::OK(); // an op with the same name can be registered for multiple execution providers } auto ver_range_it = domain_version_range_map_.find(op_domain); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index c066e822fd..51294e5dc0 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -239,19 +239,20 @@ struct MyCustomKernel { }; struct MyCustomOp : Ort::CustomOpBase { + explicit MyCustomOp(const char* provider) : provider_(provider) {} void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { return new MyCustomKernel(api, info); }; const char* GetName() const { return "Foo"; }; -#ifdef USE_CUDA - // Make the kernel run on CUDA for CUDA-enabled builds - const char* GetExecutionProviderType() const { return onnxruntime::kCudaExecutionProvider; }; -#endif + const char* GetExecutionProviderType() const { return provider_; }; size_t GetInputTypeCount() const { return 2; }; ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; size_t GetOutputTypeCount() const { return 1; }; ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + + private: + const char* provider_; }; TEST(CApiTest, custom_op_handler) { @@ -267,7 +268,12 @@ TEST(CApiTest, custom_op_handler) { std::vector expected_dims_y = {3, 2}; std::vector expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; - MyCustomOp custom_op; +#ifdef USE_CUDA + MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider}; +#else + MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider}; +#endif + Ort::CustomOpDomain custom_op_domain(""); custom_op_domain.Add(&custom_op); @@ -286,6 +292,32 @@ TEST(CApiTest, custom_op_handler) { #endif } +// Tests registration of a custom op of the same name for both CPU and CUDA EPs +#ifdef USE_CUDA +TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) { + std::cout << "Tests registration of a custom op of the same name for both CPU and CUDA EPs" << std::endl; + + std::vector inputs(1); + Input& input = inputs[0]; + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; + + MyCustomOp custom_op_cpu{onnxruntime::kCpuExecutionProvider}; + MyCustomOp custom_op_cuda{onnxruntime::kCudaExecutionProvider}; + Ort::CustomOpDomain custom_op_domain(""); + custom_op_domain.Add(&custom_op_cpu); + custom_op_domain.Add(&custom_op_cuda); + + TestInference(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, + expected_values_y, 1, custom_op_domain, nullptr, true); +} +#endif + TEST(CApiTest, DISABLED_test_custom_op_library) { std::cout << "Running inference using custom op shared library" << std::endl;