diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bfdc6b47f1..388e04f6de 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -580,6 +580,10 @@ struct OrtCustomOp { // Returns the name of the op const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op); + // Returns the type of the execution provider + // If the function pointer is null, use CPU execution provider by default + const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ struct OrtCustomOp* op); + // Returns the count and types of the input & output tensors ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ struct OrtCustomOp* op, _In_ size_t index); size_t(ORT_API_CALL* GetInputTypeCount)(_In_ struct OrtCustomOp* op); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index f343847f30..116509ec7d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -20,6 +20,8 @@ #include #include #include +#include +#include namespace Ort { @@ -276,6 +278,32 @@ struct CustomOpApi { const OrtCustomOpApi& api_; }; +namespace CustomOpImpl { + // SFINAE definition to determine whether class T as GetExecutionProviderType interface + template + class HasProvider { + template ().GetExecutionProviderType(std::declval()...) )> + static std::true_type test(int); + template + static std::false_type test(...); + public: + static constexpr bool value = decltype(test(0))::value; + }; + + // SFINAE definitions to get the execution provider of op. + // If type T has GetExecutionProviderType, use the result of op->GetExecutionProviderType(). + // Otherwise, use kCpuExecutionProvider by default. + template + std::enable_if_t::value, const char*> GetProvider(T* op) { + return op->GetExecutionProviderType(); + } + template + std::enable_if_t::value, const char*> GetProvider(T*) { + return "CPUExecutionProvider"; + } +} // namespace CustomOpImpl + template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -283,6 +311,9 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtCustomOpApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; + // If OrtCustomOp does not have a definition of GetExecutionProviderType, use CPUExecutorProvider by default + OrtCustomOp::GetExecutionProviderType = [](OrtCustomOp* this_) { return CustomOpImpl::GetProvider(static_cast(this_)); }; + OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index cc0990b13d..69a47a2595 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -130,8 +130,13 @@ common::Status CreateCustomRegistry(const std::vector& op_do KernelDefBuilder def_builder; def_builder.SetName(op->GetName(op)) .SetDomain(domain->domain_) - .SinceVersion(1) - .Provider(onnxruntime::kCpuExecutionProvider); + .SinceVersion(1); + if (op->GetExecutionProviderType) { + def_builder.Provider(op->GetExecutionProviderType(op)); + } else { + def_builder.Provider(onnxruntime::kCpuExecutionProvider); + } + KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); }; KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);