mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Allow a custom op with the same name to be registered for several providers. (#3400)
This commit is contained in:
parent
a5fea26cb4
commit
3568f8d186
2 changed files with 40 additions and 6 deletions
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/schema_registry.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
// Add customized domain to min/max version.
|
||||
|
|
@ -69,7 +70,8 @@ common::Status OnnxRuntimeOpSchemaRegistry::RegisterOpSchemaInternal(ONNX_NAMESP
|
|||
<< op_schema.line()
|
||||
<< ", but it is already registered from file "
|
||||
<< schema.file() << " line " << schema.line() << std::endl;
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostream.str());
|
||||
LOGS_DEFAULT(WARNING) << ostream.str();
|
||||
return common::Status::OK(); // an op with the same name can be registered for multiple execution providers
|
||||
}
|
||||
|
||||
auto ver_range_it = domain_version_range_map_.find(op_domain);
|
||||
|
|
|
|||
|
|
@ -239,19 +239,20 @@ struct MyCustomKernel {
|
|||
};
|
||||
|
||||
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
|
||||
explicit MyCustomOp(const char* provider) : provider_(provider) {}
|
||||
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { return new MyCustomKernel(api, info); };
|
||||
const char* GetName() const { return "Foo"; };
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// Make the kernel run on CUDA for CUDA-enabled builds
|
||||
const char* GetExecutionProviderType() const { return onnxruntime::kCudaExecutionProvider; };
|
||||
#endif
|
||||
const char* GetExecutionProviderType() const { return provider_; };
|
||||
|
||||
size_t GetInputTypeCount() const { return 2; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
private:
|
||||
const char* provider_;
|
||||
};
|
||||
|
||||
TEST(CApiTest, custom_op_handler) {
|
||||
|
|
@ -267,7 +268,12 @@ TEST(CApiTest, custom_op_handler) {
|
|||
std::vector<int64_t> expected_dims_y = {3, 2};
|
||||
std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
|
||||
|
||||
MyCustomOp custom_op;
|
||||
#ifdef USE_CUDA
|
||||
MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider};
|
||||
#else
|
||||
MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider};
|
||||
#endif
|
||||
|
||||
Ort::CustomOpDomain custom_op_domain("");
|
||||
custom_op_domain.Add(&custom_op);
|
||||
|
||||
|
|
@ -286,6 +292,32 @@ TEST(CApiTest, custom_op_handler) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// Tests registration of a custom op of the same name for both CPU and CUDA EPs
|
||||
#ifdef USE_CUDA
|
||||
TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) {
|
||||
std::cout << "Tests registration of a custom op of the same name for both CPU and CUDA EPs" << std::endl;
|
||||
|
||||
std::vector<Input> inputs(1);
|
||||
Input& input = inputs[0];
|
||||
input.name = "X";
|
||||
input.dims = {3, 2};
|
||||
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_y = {3, 2};
|
||||
std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};
|
||||
|
||||
MyCustomOp custom_op_cpu{onnxruntime::kCpuExecutionProvider};
|
||||
MyCustomOp custom_op_cuda{onnxruntime::kCudaExecutionProvider};
|
||||
Ort::CustomOpDomain custom_op_domain("");
|
||||
custom_op_domain.Add(&custom_op_cpu);
|
||||
custom_op_domain.Add(&custom_op_cuda);
|
||||
|
||||
TestInference<PATH_TYPE, float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y,
|
||||
expected_values_y, 1, custom_op_domain, nullptr, true);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(CApiTest, DISABLED_test_custom_op_library) {
|
||||
std::cout << "Running inference using custom op shared library" << std::endl;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue