mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Make C API capable of defining CUDA custom ops (#1178)
Recreate the PR on behalf of Rui Xia, for #779
This commit is contained in:
parent
b68bb51dd0
commit
d8ac0d64d0
3 changed files with 42 additions and 2 deletions
|
|
@ -580,6 +580,10 @@ struct OrtCustomOp {
|
|||
// Returns the name of the op
|
||||
const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op);
|
||||
|
||||
// Returns the type of the execution provider
|
||||
// If the function pointer is null, use CPU execution provider by default
|
||||
const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ struct OrtCustomOp* op);
|
||||
|
||||
// Returns the count and types of the input & output tensors
|
||||
ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ struct OrtCustomOp* op, _In_ size_t index);
|
||||
size_t(ORT_API_CALL* GetInputTypeCount)(_In_ struct OrtCustomOp* op);
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@
|
|||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <type_traits>
|
||||
|
||||
namespace Ort {
|
||||
|
||||
|
|
@ -276,6 +278,32 @@ struct CustomOpApi {
|
|||
const OrtCustomOpApi& api_;
|
||||
};
|
||||
|
||||
namespace CustomOpImpl {

// Compile-time trait: HasProvider<T, Args...>::value is true when an lvalue of
// type T can be called as `t.GetExecutionProviderType(Args...)`, false otherwise.
//
// The probe deliberately uses an lvalue (std::declval<U&>()): GetProvider()
// below invokes the member through a pointer, i.e. on an lvalue, so probing an
// rvalue would let the trait disagree with the real call site for
// ref-qualified member functions.
template <typename T, typename... Args>
class HasProvider {
  // Preferred overload for the argument `0` (int beats ellipsis). It is removed
  // from the overload set by SFINAE when the member call is ill-formed.
  template <typename U = T,
            typename = decltype(std::declval<U&>().GetExecutionProviderType(std::declval<Args>()...))>
  static std::true_type test(int);

  // Fallback overload, chosen only when the one above was discarded.
  template <typename U = T>
  static std::false_type test(...);

 public:
  static constexpr bool value = decltype(test(0))::value;
};

// SFINAE-selected overloads that yield the execution provider name for `op`.

// Chosen when T has a GetExecutionProviderType() member callable with no
// arguments: returns whatever op->GetExecutionProviderType() returns.
template <typename T>
std::enable_if_t<HasProvider<T>::value, const char*> GetProvider(T* op) {
  return op->GetExecutionProviderType();
}

// Chosen when T has no such member: the pointer is ignored and the CPU
// provider name is returned as the default.
// NOTE(review): presumably this literal matches onnxruntime::kCpuExecutionProvider;
// that constant is not visible from this header, so confirm they stay in sync.
template <typename T>
std::enable_if_t<!HasProvider<T>::value, const char*> GetProvider(T*) {
  return "CPUExecutionProvider";
}

}  // namespace CustomOpImpl
|
||||
|
||||
template <typename TOp, typename TKernel>
|
||||
struct CustomOpBase : OrtCustomOp {
|
||||
CustomOpBase() {
|
||||
|
|
@ -283,6 +311,9 @@ struct CustomOpBase : OrtCustomOp {
|
|||
OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtCustomOpApi* api, const OrtKernelInfo* info) { return static_cast<TOp*>(this_)->CreateKernel(*api, info); };
|
||||
OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetName(); };
|
||||
|
||||
// If TOp does not define GetExecutionProviderType, default to "CPUExecutionProvider"
|
||||
OrtCustomOp::GetExecutionProviderType = [](OrtCustomOp* this_) { return CustomOpImpl::GetProvider<TOp>(static_cast<TOp*>(this_)); };
|
||||
|
||||
OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetInputTypeCount(); };
|
||||
OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetInputType(index); };
|
||||
|
||||
|
|
|
|||
|
|
@ -130,8 +130,13 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_do
|
|||
KernelDefBuilder def_builder;
|
||||
def_builder.SetName(op->GetName(op))
|
||||
.SetDomain(domain->domain_)
|
||||
.SinceVersion(1)
|
||||
.Provider(onnxruntime::kCpuExecutionProvider);
|
||||
.SinceVersion(1);
|
||||
if (op->GetExecutionProviderType) {
|
||||
def_builder.Provider(op->GetExecutionProviderType(op));
|
||||
} else {
|
||||
def_builder.Provider(onnxruntime::kCpuExecutionProvider);
|
||||
}
|
||||
|
||||
KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); };
|
||||
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue