From cd52431b8fe3fe1a3f32039f46ed9f809b968e7c Mon Sep 17 00:00:00 2001 From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> Date: Thu, 21 Mar 2019 15:46:50 -0700 Subject: [PATCH] Custom op interface to the C API to remove shared library dependency (#668) * Adding a custom op interface to the C API to remove shared library dependency. * Fixup const issues * Renaming to make things a little simpler * Add a comment --- .../core/session/onnxruntime_c_api.h | 22 ++++++++++++-- onnxruntime/core/session/inference_session.cc | 27 +++++++++++++---- onnxruntime/test/shared_lib/test_inference.cc | 30 +++++++++++-------- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 06721297f3..8af8c820bb 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -538,8 +538,24 @@ typedef struct OrtKernelInfo OrtKernelInfo; /* * These allow reading node attributes during kernel creation */ -ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); -ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); +ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); +ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); + +struct OrtCustomOpApi { + OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); + OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); + + OrtStatus*(ORT_API_CALL* GetTensorShapeAndType)(_In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out); + + size_t(ORT_API_CALL* GetNumOfDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info); + void(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); + OrtStatus*(ORT_API_CALL* SetDims)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + + OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Out_ void** out); + + void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(OrtTensorTypeAndShapeInfo* input); +}; +typedef struct OrtCustomOpApi OrtCustomOpApi; /* * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by @@ -549,7 +565,7 @@ struct OrtCustomOp { uint32_t version; // Initialize to ORT_API_VERSION // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. - void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ OrtKernelInfo* info, _Out_ void** op_kernel); + void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info, _Out_ void** op_kernel); // Returns the name of the op const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 505a6d5ee7..935343ba83 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -58,17 +58,32 @@ using namespace ONNX_NAMESPACE; +constexpr OrtCustomOpApi g_custom_op_api = { + &OrtKernelInfoGetAttribute_float, + &OrtKernelInfoGetAttribute_int64, + + &OrtGetTensorShapeAndType, + + &OrtGetNumOfDimensions, + &OrtGetDimensions, + &OrtSetDims, + + &OrtGetTensorMutableData, + + &OrtReleaseTensorTypeAndShapeInfo, +}; + ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type); -ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { - auto status = reinterpret_cast(info)->GetAttr(name, out); +ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { + auto status = reinterpret_cast(info)->GetAttr(name, out); if (status.IsOK()) return nullptr; return onnxruntime::ToOrtStatus(status); } -ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { - auto status = reinterpret_cast(info)->GetAttr(name, out); +ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { + auto status = reinterpret_cast(info)->GetAttr(name, out); if (status.IsOK()) return nullptr; return onnxruntime::ToOrtStatus(status); @@ -111,7 +126,9 @@ inline std::basic_string GetCurrentTimeString() { } // namespace struct CustomOpKernel : OpKernel { CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) { - op_.CreateKernel(&op_, reinterpret_cast(const_cast(&info)), &op_kernel_); + if (op_.version != 1) + throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op)); + op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast(const_cast(&info)), &op_kernel_); } ~CustomOpKernel() { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 13f57b852c..6475c34798 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -156,13 +156,13 @@ INSTANTIATE_TEST_CASE_P(CApiTestWithProviders, ::testing::Values(0, 1, 2, 3, 4)); struct OrtTensorDimensions : std::vector { - OrtTensorDimensions(OrtValue* value) { + OrtTensorDimensions(const OrtCustomOpApi& ort, OrtValue* value) { OrtTensorTypeAndShapeInfo* info; - ORT_THROW_ON_ERROR(OrtGetTensorShapeAndType(value, &info)); - auto dimensionCount = OrtGetNumOfDimensions(info); + ORT_THROW_ON_ERROR(ort.GetTensorShapeAndType(value, &info)); + auto dimensionCount = ort.GetNumOfDimensions(info); resize(dimensionCount); - OrtGetDimensions(info, data(), dimensionCount); - OrtReleaseTensorTypeAndShapeInfo(info); + ort.GetDimensions(info, data(), dimensionCount); + ort.ReleaseTensorTypeAndShapeInfo(info); } size_t ElementCount() const { @@ -173,38 +173,42 @@ struct OrtTensorDimensions : std::vector { } }; +// Once we use C++17 this could be replaced with std::size template constexpr size_t countof(T (&)[N]) { return N; } struct MyCustomKernel { - MyCustomKernel(OrtKernelInfo& /*info*/) { + MyCustomKernel(const OrtCustomOpApi& ort, const OrtKernelInfo& /*info*/) : ort_(ort) { } void GetOutputShape(OrtValue** inputs, size_t /*input_count*/, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) { - OrtTensorDimensions dimensions(inputs[0]); - ORT_THROW_ON_ERROR(OrtSetDims(info, dimensions.data(), dimensions.size())); + OrtTensorDimensions dimensions(ort_, inputs[0]); + ORT_THROW_ON_ERROR(ort_.SetDims(info, dimensions.data(), dimensions.size())); } void Compute(OrtValue** inputs, size_t /*input_count*/, OrtValue** outputs, size_t /*output_count*/) { const float* X; const float* Y; - ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[0], reinterpret_cast(const_cast(&X)))); - ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[1], reinterpret_cast(const_cast(&Y)))); + ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[0], reinterpret_cast(const_cast(&X)))); + ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[1], reinterpret_cast(const_cast(&Y)))); float* out; - ORT_THROW_ON_ERROR(OrtGetTensorMutableData(outputs[0], reinterpret_cast(&out))); + ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(outputs[0], reinterpret_cast(&out))); - int64_t size = OrtTensorDimensions(inputs[0]).ElementCount(); + int64_t size = OrtTensorDimensions(ort_, inputs[0]).ElementCount(); for (int64_t i = 0; i < size; i++) { out[i] = X[i] + Y[i]; } } + + private: + const OrtCustomOpApi& ort_; }; struct MyCustomOp : OrtCustomOp { MyCustomOp() { OrtCustomOp::version = ORT_API_VERSION; - OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*info); }; + OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, const OrtCustomOpApi* api, const OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*api, *info); }; OrtCustomOp::GetName = [](OrtCustomOp* /*this_*/) { return "Foo"; }; static const ONNXTensorElementDataType c_inputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};