mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Allow custom operator functions to safely propagate errors through the C-API (#16479)
### Description This PR implements a backward-compatible way to define custom operators with fallible compute functions. The C++ API template gained an optional `Fallible` argument (named `WithStatus` in the diff below). Closes #14287 ### Motivation and Context #14287 contains more context. The gist is that the current C-API defines the compute operations of custom operators as functions returning `void` rather than an `OrtStatusPtr`. Currently, errors are often propagated across the C-ABI using C++ exceptions. That is unsafe and constitutes undefined behavior. Moreover, it is difficult for languages other than C++ to use this approach even if they wanted to. A C-compliant, sound, and safe way to propagate errors allows for fallible custom operators written in languages other than C++. ### An example in action https://github.com/cbourjau/ort-custom-op/pull/6/files is a demonstration of how this PR can be used to write safe and fallible custom operators in Rust.
This commit is contained in:
parent
15f16ef36e
commit
6dd4e4801a
5 changed files with 62 additions and 16 deletions
|
|
@ -4346,7 +4346,10 @@ typedef enum OrtCustomOpInputOutputCharacteristic {
|
|||
struct OrtCustomOp {
|
||||
uint32_t version; // Must be initialized to ORT_API_VERSION
|
||||
|
||||
// This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
|
||||
// This callback creates the kernel, which is a user defined
|
||||
// parameter that is passed to the Kernel* callbacks below. It is
|
||||
// recommended to use CreateKernelV2, which allows for safe error
|
||||
// propagation by returning an OrtStatusPtr.
|
||||
void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api,
|
||||
_In_ const OrtKernelInfo* info);
|
||||
|
||||
|
|
@ -4362,7 +4365,9 @@ struct OrtCustomOp {
|
|||
ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
|
||||
size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op);
|
||||
|
||||
// Op kernel callbacks
|
||||
// Perform a computation step. It is recommended to use
|
||||
// KernelComputeV2, which allows for safe error propagation by
|
||||
// returning an OrtStatusPtr.
|
||||
void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
|
||||
void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel);
|
||||
|
||||
|
|
@ -4394,6 +4399,14 @@ struct OrtCustomOp {
|
|||
// and false (zero) otherwise.
|
||||
// Applicable only for custom ops that have a variadic output.
|
||||
int(ORT_API_CALL* GetVariadicOutputHomogeneity)(_In_ const struct OrtCustomOp* op);
|
||||
|
||||
// Create the kernel state which is passed to each compute call.
|
||||
OrtStatusPtr(ORT_API_CALL* CreateKernelV2)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api,
|
||||
_In_ const OrtKernelInfo* info,
|
||||
_Out_ void** kernel);
|
||||
|
||||
// Perform the computation step.
|
||||
OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -1847,11 +1847,10 @@ struct Op : detail::Base<OrtOp> {
|
|||
size_t output_count);
|
||||
};
|
||||
|
||||
template <typename TOp, typename TKernel>
|
||||
template <typename TOp, typename TKernel, bool WithStatus = false>
|
||||
struct CustomOpBase : OrtCustomOp {
|
||||
CustomOpBase() {
|
||||
OrtCustomOp::version = ORT_API_VERSION;
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
|
||||
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
|
||||
|
|
@ -1863,7 +1862,6 @@ struct CustomOpBase : OrtCustomOp {
|
|||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
|
||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 26409)
|
||||
|
|
@ -1879,6 +1877,22 @@ struct CustomOpBase : OrtCustomOp {
|
|||
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
|
||||
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
|
||||
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
|
||||
if constexpr (WithStatus) {
|
||||
OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
|
||||
return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
|
||||
};
|
||||
OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
|
||||
return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
|
||||
};
|
||||
} else {
|
||||
OrtCustomOp::CreateKernelV2 = nullptr;
|
||||
OrtCustomOp::KernelComputeV2 = nullptr;
|
||||
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
static_cast<TKernel*>(op_kernel)->Compute(context);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
||||
|
|
|
|||
|
|
@ -1812,9 +1812,9 @@ inline std::vector<std::string> GetAvailableProviders() {
|
|||
return available_providers;
|
||||
}
|
||||
|
||||
template <typename TOp, typename TKernel>
|
||||
void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
|
||||
ConstSessionOptions options) const {
|
||||
template <typename TOp, typename TKernel, bool WithStatus>
|
||||
void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
|
||||
ConstSessionOptions options) const {
|
||||
const TOp* derived = static_cast<const TOp*>(this);
|
||||
std::vector<std::string> keys = derived->GetSessionConfigKeys();
|
||||
|
||||
|
|
|
|||
|
|
@ -402,13 +402,30 @@ struct CustomOpKernel : OpKernel {
|
|||
ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
|
||||
}
|
||||
|
||||
op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version),
|
||||
reinterpret_cast<const OrtKernelInfo*>(&info));
|
||||
if (op_.version > 15 && op_.KernelCompute == 0) {
|
||||
op_kernel_ = nullptr;
|
||||
Ort::ThrowOnError(
|
||||
op_.CreateKernelV2(
|
||||
&op_,
|
||||
OrtGetApiBase()->GetApi(op_.version),
|
||||
reinterpret_cast<const OrtKernelInfo*>(&info),
|
||||
&op_kernel_));
|
||||
} else {
|
||||
op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version),
|
||||
reinterpret_cast<const OrtKernelInfo*>(&info));
|
||||
}
|
||||
}
|
||||
|
||||
~CustomOpKernel() override { op_.KernelDestroy(op_kernel_); }
|
||||
~CustomOpKernel() override {
|
||||
op_.KernelDestroy(op_kernel_);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override {
|
||||
if (op_.version > 15 && op_.KernelCompute == 0) {
|
||||
auto status_ptr = op_.KernelComputeV2(op_kernel_, reinterpret_cast<OrtKernelContext*>(ctx));
|
||||
return ToStatus(status_ptr);
|
||||
}
|
||||
|
||||
op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ctx));
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ void cuda_add(int64_t, T3*, const T1*, const T2*, cudaStream_t compute_stream);
|
|||
static const char* c_OpDomain = "test.customop";
|
||||
|
||||
struct KernelOne {
|
||||
void Compute(OrtKernelContext* context) {
|
||||
OrtStatusPtr ComputeV2(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
Ort::KernelContext ctx(context);
|
||||
auto input_X = ctx.GetInput(0);
|
||||
|
|
@ -45,13 +45,15 @@ struct KernelOne {
|
|||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// legacy custom op registration
|
||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||
void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const {
|
||||
return std::make_unique<KernelOne>().release();
|
||||
// legacy custom op registration with kernel creation and compute functions that return an OrtStatusPtr
|
||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne, true> {
|
||||
OrtStatusPtr CreateKernelV2(const OrtApi& /* api */, const OrtKernelInfo* /* info */, void** op_kernel) const {
|
||||
*op_kernel = reinterpret_cast<void*>(std::make_unique<KernelOne>().release());
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
const char* GetName() const { return "CustomOpOne"; };
|
||||
|
|
|
|||
Loading…
Reference in a new issue