Make C API capable of defining CUDA custom ops (#1178)

Recreate the PR on behalf of Rui Xia, for #779
This commit is contained in:
Changming Sun 2019-06-06 13:45:32 -07:00 committed by GitHub
parent b68bb51dd0
commit d8ac0d64d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 2 deletions

View file

@ -580,6 +580,10 @@ struct OrtCustomOp {
// Returns the name of the op
const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op);
// Returns the type of the execution provider
// If the function pointer is null, use CPU execution provider by default
const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ struct OrtCustomOp* op);
// Returns the count and types of the input & output tensors
ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ struct OrtCustomOp* op, _In_ size_t index);
size_t(ORT_API_CALL* GetInputTypeCount)(_In_ struct OrtCustomOp* op);

View file

@ -20,6 +20,8 @@
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <type_traits>
namespace Ort {
@ -276,6 +278,32 @@ struct CustomOpApi {
const OrtCustomOpApi& api_;
};
namespace CustomOpImpl {
// SFINAE definition to determine whether class T as GetExecutionProviderType interface
template <typename T, typename... Args>
class HasProvider {
template <typename U = T,
typename = decltype( std::declval<U>().GetExecutionProviderType(std::declval<Args>()...) )>
static std::true_type test(int);
template <typename U = T>
static std::false_type test(...);
public:
static constexpr bool value = decltype(test(0))::value;
};
// SFINAE definitions to get the execution provider of op.
// If type T has GetExecutionProviderType, use the result of op->GetExecutionProviderType().
// Otherwise, use kCpuExecutionProvider by default.
template <typename T>
std::enable_if_t<HasProvider<T>::value, const char*> GetProvider(T* op) {
return op->GetExecutionProviderType();
}
template <typename T>
std::enable_if_t<!HasProvider<T>::value, const char*> GetProvider(T*) {
return "CPUExecutionProvider";
}
} // namespace CustomOpImpl
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
@ -283,6 +311,9 @@ struct CustomOpBase : OrtCustomOp {
OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtCustomOpApi* api, const OrtKernelInfo* info) { return static_cast<TOp*>(this_)->CreateKernel(*api, info); };
OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetName(); };
// If OrtCustomOp does not have a definition of GetExecutionProviderType, use CPUExecutorProvider by default
OrtCustomOp::GetExecutionProviderType = [](OrtCustomOp* this_) { return CustomOpImpl::GetProvider<TOp>(static_cast<TOp*>(this_)); };
OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetInputType(index); };

View file

@ -130,8 +130,13 @@ common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_do
KernelDefBuilder def_builder;
def_builder.SetName(op->GetName(op))
.SetDomain(domain->domain_)
.SinceVersion(1)
.Provider(onnxruntime::kCpuExecutionProvider);
.SinceVersion(1);
if (op->GetExecutionProviderType) {
def_builder.Provider(op->GetExecutionProviderType(op));
} else {
def_builder.Provider(onnxruntime::kCpuExecutionProvider);
}
KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); };
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);