diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e804e1b0d2..37f828214a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4346,7 +4346,10 @@ typedef enum OrtCustomOpInputOutputCharacteristic { struct OrtCustomOp { uint32_t version; // Must be initialized to ORT_API_VERSION - // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. + // This callback creates the kernel, which is a user defined + // parameter that is passed to the Kernel* callbacks below. It is + // recommended to use CreateKernelV2, which allows safe error + // propagation by returning an OrtStatusPtr. void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, _In_ const OrtKernelInfo* info); @@ -4362,7 +4365,9 @@ struct OrtCustomOp { ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); - // Op kernel callbacks + // Perform a computation step. It is recommended to use + // KernelComputeV2, which allows safe error propagation by + // returning an OrtStatusPtr. void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); @@ -4394,6 +4399,14 @@ struct OrtCustomOp { // and false (zero) otherwise. // Applicable only for custom ops that have a variadic output. int(ORT_API_CALL* GetVariadicOutputHomogeneity)(_In_ const struct OrtCustomOp* op); + + // Create the kernel state that is passed to each compute call. Only used when version > 15 and KernelCompute is null. + OrtStatusPtr(ORT_API_CALL* CreateKernelV2)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info, + _Out_ void** kernel); + + // Perform the computation step. Only used when version > 15 and KernelCompute is null.
+ OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 409f26f0e2..69cec42895 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1847,11 +1847,10 @@ struct Op : detail::Base { size_t output_count); }; -template +template struct CustomOpBase : OrtCustomOp { CustomOpBase() { OrtCustomOp::version = ORT_API_VERSION; - OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast(this_)->GetExecutionProviderType(); }; @@ -1863,7 +1862,6 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26409) @@ -1879,6 +1877,25 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicInputHomogeneity()); }; OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast(this_)->GetVariadicOutputMinArity(); }; OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast(static_cast(this_)->GetVariadicOutputHomogeneity()); }; + if constexpr 
(WithStatus) { + OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { + return static_cast(this_)->CreateKernelV2(*api, info, op_kernel); + }; + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + return static_cast(op_kernel)->ComputeV2(context); + }; + // Clear the legacy callbacks; the runtime tests KernelCompute for null to select the V2 path. + OrtCustomOp::CreateKernel = nullptr; + OrtCustomOp::KernelCompute = nullptr; + } else { + OrtCustomOp::CreateKernelV2 = nullptr; + OrtCustomOp::KernelComputeV2 = nullptr; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + static_cast(op_kernel)->Compute(context); + }; + } } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index a8db834cc1..c0d45c0541 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1812,9 +1812,9 @@ inline std::vector GetAvailableProviders() { return available_providers; } -template -void CustomOpBase::GetSessionConfigs(std::unordered_map& out, - ConstSessionOptions options) const { +template +void CustomOpBase::GetSessionConfigs(std::unordered_map& out, + ConstSessionOptions options) const { const TOp* derived = static_cast(this); std::vector keys = derived->GetSessionConfigKeys(); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 1df91d528a..65d3d06c1c 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ 
-402,13 +402,30 @@ struct CustomOpKernel : OpKernel { ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op)); } - op_kernel_ = 
op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), - reinterpret_cast(&info)); + if (op_.version > 15 && op_.KernelCompute == nullptr) { + op_kernel_ = nullptr; + Ort::ThrowOnError( + op_.CreateKernelV2( + &op_, + OrtGetApiBase()->GetApi(op_.version), + reinterpret_cast(&info), + &op_kernel_)); + } else { + op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), + reinterpret_cast(&info)); + } } - ~CustomOpKernel() override { op_.KernelDestroy(op_kernel_); } + ~CustomOpKernel() override { + op_.KernelDestroy(op_kernel_); + } Status Compute(OpKernelContext* ctx) const override { + if (op_.version > 15 && op_.KernelCompute == nullptr) { + auto status_ptr = op_.KernelComputeV2(op_kernel_, reinterpret_cast(ctx)); + return ToStatus(status_ptr); + } + op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); return Status::OK(); } diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 5d6465984a..ccae8a034f 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -20,7 +20,7 @@ void cuda_add(int64_t, T3*, const T1*, const T2*, cudaStream_t compute_stream); static const char* c_OpDomain = "test.customop"; struct KernelOne { - void Compute(OrtKernelContext* context) { + OrtStatusPtr ComputeV2(OrtKernelContext* context) { // Setup inputs Ort::KernelContext ctx(context); auto input_X = ctx.GetInput(0); @@ -45,13 +45,15 @@ struct KernelOne { out[i] = X[i] + Y[i]; } #endif + return nullptr; } }; -// legacy custom op registration -struct CustomOpOne : Ort::CustomOpBase { - void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const { - return std::make_unique().release(); +// legacy custom op registration with kernel creation and compute functions that return an OrtStatusPtr +struct CustomOpOne : Ort::CustomOpBase { + OrtStatusPtr CreateKernelV2(const 
OrtApi& /* api */, const OrtKernelInfo* /* info */, void** op_kernel) const { + *op_kernel = reinterpret_cast(std::make_unique().release()); + return nullptr; }; const char* GetName() const { return "CustomOpOne"; };