From fda1d0dce926f26056da95ff30eb456ffc7cf8cc Mon Sep 17 00:00:00 2001 From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> Date: Fri, 5 Apr 2019 18:53:20 -0700 Subject: [PATCH] Ryanunderhill/ocr custom op (#744) * Adding a custom op interface to the C API to remove shared library dependency. * Remove old custom op test * Rework how custom ops handle inputs/outputs to enable custom op output shape calculation in the compute method * Add a nicer C++ API for custom ops and switch the tests to use it. --- cmake/onnxruntime_unittests.cmake | 12 +- .../C_Api_Sample.cpp | 11 +- .../core/framework/custom_ops_author.h | 28 ---- .../core/session/onnxruntime_c_api.h | 28 ++-- .../core/session/onnxruntime_cxx_api.h | 84 +++++++++++ onnxruntime/core/session/custom_ops.cc | 138 ++++++++++++++++++ onnxruntime/core/session/custom_ops.h | 8 + onnxruntime/core/session/inference_session.cc | 121 +-------------- .../custom_op_shared_lib/test_custom_op.cc | 83 ----------- onnxruntime/test/shared_lib/test_inference.cc | 127 ++++++++-------- 10 files changed, 326 insertions(+), 314 deletions(-) delete mode 100644 include/onnxruntime/core/framework/custom_ops_author.h create mode 100644 onnxruntime/core/session/custom_ops.cc create mode 100644 onnxruntime/core/session/custom_ops.h delete mode 100644 onnxruntime/test/custom_op_shared_lib/test_custom_op.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index fe210d551a..0fa7419482 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -474,22 +474,12 @@ set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") # shared lib if (onnxruntime_BUILD_SHARED_LIB) - if (UNIX) - # test custom op shared lib - file(GLOB onnxruntime_custom_op_shared_lib_test_srcs "${ONNXRUNTIME_ROOT}/test/custom_op_shared_lib/test_custom_op.cc") - add_library(onnxruntime_custom_op_shared_lib_test SHARED ${onnxruntime_custom_op_shared_lib_test_srcs}) - onnxruntime_add_include_to_target(onnxruntime_custom_op_shared_lib_test gsl) - add_dependencies(onnxruntime_custom_op_shared_lib_test onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_include_directories(onnxruntime_custom_op_shared_lib_test PUBLIC "${PROJECT_SOURCE_DIR}/include") - target_link_libraries(onnxruntime_custom_op_shared_lib_test PRIVATE onnxruntime onnx onnx_proto protobuf::libprotobuf) - set_target_properties(onnxruntime_custom_op_shared_lib_test PROPERTIES FOLDER "ONNXRuntimeSharedLibTest") - endif() add_library(onnxruntime_mocked_allocator ${ONNXRUNTIME_ROOT}/test/util/test_allocator.cc) target_include_directories(onnxruntime_mocked_allocator PUBLIC ${ONNXRUNTIME_ROOT}/test/util/include) set_target_properties(onnxruntime_mocked_allocator PROPERTIES FOLDER "ONNXRuntimeTest") ################################################################# - # test inference using shared lib + custom op + # test inference using shared lib set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib") set (onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp index da7efe3c78..db033bcb4a 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp @@ -29,15 +29,15 @@ int main(int argc, char* argv[]) { CHECK_STATUS(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env)); // initialize session options if needed - OrtSessionOptions* session_option = OrtCreateSessionOptions(); - OrtSetSessionThreadPoolSize(session_option, 1); + OrtSessionOptions* session_options = OrtCreateSessionOptions(); + OrtSetSessionThreadPoolSize(session_options, 1); // Sets graph optimization level - // Available levels are + // Available levels are // 0 -> To disable all optimizations // 1 -> To enable basic optimizations (Such as redundant node removals) // 2 -> To enable all optimizations (Includes level 1 + more complex optimizations like node fusions) - OrtSetSessionGraphOptimizationLevel(session_option, 1); + OrtSetSessionGraphOptimizationLevel(session_options, 1); //************************************************************************* // create session and load model into memory @@ -45,7 +45,7 @@ int main(int argc, char* argv[]) { // URL = https://github.com/onnx/models/tree/master/squeezenet OrtSession* session; const wchar_t* model_path = L"squeezenet.onnx"; - CHECK_STATUS(OrtCreateSession(env, model_path, session_option, &session)); + CHECK_STATUS(OrtCreateSession(env, model_path, session_options, &session)); //************************************************************************* // print model input layer (node names, types, shape etc.) @@ -149,6 +149,7 @@ int main(int argc, char* argv[]) { OrtReleaseValue(output_tensor); OrtReleaseValue(input_tensor); OrtReleaseSession(session); + OrtReleaseSessionOptions(session_options); OrtReleaseEnv(env); printf("Done!\n"); return 0; diff --git a/include/onnxruntime/core/framework/custom_ops_author.h b/include/onnxruntime/core/framework/custom_ops_author.h deleted file mode 100644 index f5a03a3e0a..0000000000 --- a/include/onnxruntime/core/framework/custom_ops_author.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/** - this header should include all the headers that are required to build a custom op so that - custom op developers don't have to worry about which headers to include, etc. -*/ -#include "core/framework/op_kernel.h" - -struct KernelsContainer { - std::vector<::onnxruntime::KernelCreateInfo> kernels_list; -}; - -struct SchemasContainer { - std::vector schemas_list; - std::string domain; - int baseline_opset_version; - int opset_version; -}; - -extern "C" { - KernelsContainer* GetAllKernels(); - SchemasContainer* GetAllSchemas(); - void FreeKernelsContainer(KernelsContainer*); - void FreeSchemasContainer(SchemasContainer*); -} diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 54085a2c08..2dcd728705 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -534,26 +534,28 @@ ORT_API_STATUS(OrtCreateValue, OrtValue** const in, int num_values, enum ONNXTyp */ struct OrtKernelInfo; typedef struct OrtKernelInfo OrtKernelInfo; - -/* - * These allow reading node attributes during kernel creation -*/ -ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); -ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); +struct OrtKernelContext; +typedef struct OrtKernelContext OrtKernelContext; struct OrtCustomOpApi { + /* + * These allow reading node attributes during kernel creation + */ OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); OrtStatus*(ORT_API_CALL* GetTensorShapeAndType)(_In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out); - size_t(ORT_API_CALL* GetNumOfDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info); + int64_t(ORT_API_CALL* GetTensorShapeElementCount)(_In_ const OrtTensorTypeAndShapeInfo* info); + size_t(ORT_API_CALL* GetDimensionCount)(_In_ const OrtTensorTypeAndShapeInfo* info); void(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); - OrtStatus*(ORT_API_CALL* SetDims)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); - - OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Out_ void** out); + OrtStatus*(ORT_API_CALL* SetDimensions)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, void** data); void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(OrtTensorTypeAndShapeInfo* input); + + OrtValue*(ORT_API_CALL* KernelContext_GetInput)(OrtKernelContext* context, _In_ size_t index); + OrtValue*(ORT_API_CALL* KernelContext_GetOutput)(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); }; typedef struct OrtCustomOpApi OrtCustomOpApi; @@ -565,7 +567,7 @@ struct OrtCustomOp { uint32_t version; // Initialize to ORT_API_VERSION // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. - void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info, _Out_ void** op_kernel); + void*(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info); // Returns the name of the op const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op); @@ -577,8 +579,8 @@ struct OrtCustomOp { size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ struct OrtCustomOp* op); // Op kernel callbacks - void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output); - void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ OrtValue** outputs, _In_ size_t output_count); + void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtKernelContext* context, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output); + void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); }; typedef struct OrtCustomOp OrtCustomOp; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 0d6ce33c6b..c0be19e149 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -147,6 +147,90 @@ inline std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info OrtGetDimensions(info, ret.data(), ret.size()); return ret; } + +struct CustomOpApi { + CustomOpApi(const OrtCustomOpApi& api) : api_(api) {} + + template + T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); + + OrtTensorTypeAndShapeInfo* GetTensorShapeAndType(_In_ const OrtValue* value) { + OrtTensorTypeAndShapeInfo* out; + ORT_THROW_ON_ERROR(api_.GetTensorShapeAndType(value, &out)); + return out; + } + + int64_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) { + return api_.GetTensorShapeElementCount(info); + } + + size_t GetDimensionCount(_In_ const OrtTensorTypeAndShapeInfo* info) { + return api_.GetDimensionCount(info); + } + + void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { + api_.GetDimensions(info, dim_values, dim_values_length); + } + + void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { + api_.SetDimensions(info, dim_values, dim_count); + } + + template + T* GetTensorMutableData(_Inout_ OrtValue* value) { + T* data; + ORT_THROW_ON_ERROR(api_.GetTensorMutableData(value, reinterpret_cast(&data))); + return data; + } + + void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) { + api_.ReleaseTensorTypeAndShapeInfo(input); + } + + OrtValue* KernelContext_GetInput(OrtKernelContext* context, _In_ size_t index) { + return api_.KernelContext_GetInput(context, index); + } + OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) { + return api_.KernelContext_GetOutput(context, index, dim_values, dim_count); + } + + private: + const OrtCustomOpApi& api_; +}; + +template <> +inline float CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + float out; + ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_float(info, name, &out)); + return out; +} + +template <> +inline int64_t CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + int64_t out; + ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_int64(info, name, &out)); + return out; +} + +template +struct CustomOpBase : OrtCustomOp { + CustomOpBase() { + OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtCustomOpApi* api, const OrtKernelInfo* info) { return static_cast(this_)->CreateKernel(*api, info); }; + OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast(this_)->GetName(); }; + + OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast(this_)->GetInputTypeCount(); }; + OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetInputType(index); }; + + OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* this_) { return static_cast(this_)->GetOutputTypeCount(); }; + OrtCustomOp::GetOutputType = [](OrtCustomOp* this_, size_t index) { return static_cast(this_)->GetOutputType(index); }; + + OrtCustomOp::KernelGetOutputShape = [](void* op_kernel, OrtKernelContext* context, size_t output_index, OrtTensorTypeAndShapeInfo* output) { static_cast(op_kernel)->GetOutputShape(context, output_index, output); }; + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast(op_kernel)->Compute(context); }; + OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; + } +}; + } // namespace onnxruntime #undef ORT_REDIRECT_SIMPLE_FUNCTION_CALL diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc new file mode 100644 index 0000000000..165bb13ecd --- /dev/null +++ b/onnxruntime/core/session/custom_ops.cc @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef _WIN32 +#pragma warning(disable : 4267) +#endif + +#include "core/session/inference_session.h" +#include "core/framework/customregistry.h" +#include "core/framework/data_types.h" +#include "core/framework/op_kernel_info.h" +#include "core/framework/op_kernel_context_internal.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/tensor_type_and_shape.h" + +ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type); + +ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { + auto status = reinterpret_cast(info)->GetAttr(name, out); + if (status.IsOK()) + return nullptr; + return onnxruntime::ToOrtStatus(status); +} + +ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { + auto status = reinterpret_cast(info)->GetAttr(name, out); + if (status.IsOK()) + return nullptr; + return onnxruntime::ToOrtStatus(status); +} + +OrtValue* OrtKernelContext_GetInput(OrtKernelContext* context, _In_ size_t index) { + return reinterpret_cast(const_cast(reinterpret_cast(context)->GetInputMLValue(index))); +}; + +OrtValue* OrtKernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) { + onnxruntime::TensorShape shape(dim_values, dim_count); + return reinterpret_cast(reinterpret_cast(context)->OutputMLValue(index, shape)); +}; + +constexpr OrtCustomOpApi g_custom_op_api = { + &OrtKernelInfoGetAttribute_float, + &OrtKernelInfoGetAttribute_int64, + + &OrtGetTensorShapeAndType, + + &OrtGetTensorShapeElementCount, + &OrtGetNumOfDimensions, + &OrtGetDimensions, + &OrtSetDims, + + &OrtGetTensorMutableData, + + &OrtReleaseTensorTypeAndShapeInfo, + + &OrtKernelContext_GetInput, + &OrtKernelContext_GetOutput, +}; + +namespace onnxruntime { + +struct CustomOpKernel : OpKernel { + CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) { + if (op_.version != 1) + throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op)); + op_kernel_ = op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast(const_cast(&info))); + } + + ~CustomOpKernel() { + op_.KernelDestroy(op_kernel_); + } + + Status Compute(OpKernelContext* ctx) const override { + auto* ictx = static_cast(ctx); + op_.KernelCompute(op_kernel_, reinterpret_cast(ictx)); + return Status::OK(); + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel); + + OrtCustomOp& op_; + void* op_kernel_; +}; + +common::Status CreateCustomRegistry(const std::vector& op_domains, std::shared_ptr& output) { + output = std::make_shared(); + + for (auto& domain : op_domains) { + if (domain->domain_[0]) + ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(domain->domain_, 1, 1000); + + std::vector schemas_list; + + for (auto& op : domain->custom_ops_) { + ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0); + + auto input_count = op->GetInputTypeCount(op); + for (size_t i = 0; i < input_count; i++) { + auto type = op->GetInputType(op, i); + + schema.Input(i, "A", "Description", + DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); + } + + auto output_count = op->GetOutputTypeCount(op); + for (size_t i = 0; i < output_count; i++) { + auto type = op->GetOutputType(op, i); + + schema.Output(i, "A", "Description", + DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); + } + + schema.SetDomain(domain->domain_); + schema.SinceVersion(1); + schema.AllowUncheckedAttributes(); + schemas_list.push_back(schema); + + KernelDefBuilder def_builder; + def_builder.SetName(op->GetName(op)) + .SetDomain(domain->domain_) + .SinceVersion(1) + .Provider(onnxruntime::kCpuExecutionProvider); + KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); }; + KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); + + output->RegisterCustomKernel(create_info); + } + + ORT_RETURN_IF_ERROR(output->RegisterOpSet(schemas_list, + domain->domain_, + 1 /* baseline opset version */, + 1000 /* opset version */)); + } + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/custom_ops.h b/onnxruntime/core/session/custom_ops.h new file mode 100644 index 0000000000..7afbe22737 --- /dev/null +++ b/onnxruntime/core/session/custom_ops.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace onnxruntime { + +common::Status CreateCustomRegistry(const std::vector& op_domains, std::shared_ptr& output); + +} diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 3cc4455618..e58fa5739a 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -43,8 +43,8 @@ #include "core/optimizer/insert_cast_transformer.h" #include "core/optimizer/transformer_memcpy.h" #include "core/providers/cpu/cpu_execution_provider.h" -#include "core/framework/custom_ops_author.h" #include "core/session/IOBinding.h" +#include "core/session/custom_ops.h" #include "core/util/protobuf_parsing_utils.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/graph_transformer_utils.h" @@ -55,37 +55,6 @@ using namespace ONNX_NAMESPACE; -constexpr OrtCustomOpApi g_custom_op_api = { - &OrtKernelInfoGetAttribute_float, - &OrtKernelInfoGetAttribute_int64, - - &OrtGetTensorShapeAndType, - - &OrtGetNumOfDimensions, - &OrtGetDimensions, - &OrtSetDims, - - &OrtGetTensorMutableData, - - &OrtReleaseTensorTypeAndShapeInfo, -}; - -ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type); - -ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { - auto status = reinterpret_cast(info)->GetAttr(name, out); - if (status.IsOK()) - return nullptr; - return onnxruntime::ToOrtStatus(status); -} - -ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { - auto status = reinterpret_cast(info)->GetAttr(name, out); - if (status.IsOK()) - return nullptr; - return onnxruntime::ToOrtStatus(status); -} - namespace onnxruntime { namespace { template @@ -121,42 +90,6 @@ inline std::basic_string GetCurrentTimeString() { return std::basic_string(time_str); } } // namespace -struct CustomOpKernel : OpKernel { - CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) { - if (op_.version != 1) - throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op)); - op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast(const_cast(&info)), &op_kernel_); - } - - ~CustomOpKernel() { - op_.KernelDestroy(op_kernel_); - } - - Status Compute(OpKernelContext* ctx) const override { - auto* ictx = static_cast(ctx); - std::vector input_tensors; - auto input_count = ictx->InputCount(); - for (int i = 0; i < input_count; i++) - input_tensors.emplace_back(const_cast(reinterpret_cast(ictx->GetInputMLValue(i)))); - - std::vector output_tensors; - auto output_count = ictx->OutputCount(); - for (int i = 0; i < output_count; i++) { - OrtTensorTypeAndShapeInfo info; - op_.KernelGetOutputShape(op_kernel_, input_tensors.data(), input_tensors.size(), i, &info); - output_tensors.emplace_back(reinterpret_cast(ictx->OutputMLValue(0, info.shape))); - } - - op_.KernelCompute(op_kernel_, input_tensors.data(), input_tensors.size(), output_tensors.data(), output_tensors.size()); - return Status::OK(); - } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel); - - OrtCustomOp& op_; - void* op_kernel_; -}; InferenceSession::InferenceSession(const SessionOptions& session_options, logging::LoggingManager* logging_manager) : session_state_{execution_providers_}, @@ -236,55 +169,8 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector& op_domains) { - auto custom_registry = std::make_shared(); - - for (auto& domain : op_domains) { - SchemasContainer schemas_container; - - schemas_container.domain = domain->domain_; - schemas_container.baseline_opset_version = 1; - schemas_container.opset_version = 1000; - - for (auto& op : domain->custom_ops_) { - ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0); - - auto input_count = op->GetInputTypeCount(op); - for (size_t i = 0; i < input_count; i++) { - auto type = op->GetInputType(op, i); - - schema.Input(i, "A", "Description", - DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); - } - - auto output_count = op->GetOutputTypeCount(op); - for (size_t i = 0; i < output_count; i++) { - auto type = op->GetOutputType(op, i); - - schema.Output(i, "A", "Description", - DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); - } - - schema.SinceVersion(1); - schema.AllowUncheckedAttributes(); - - schemas_container.schemas_list.push_back(schema); - - KernelDefBuilder def_builder; - def_builder.SetName(op->GetName(op)) - .SetDomain(onnxruntime::kOnnxDomain) - .SinceVersion(1) - .Provider(onnxruntime::kCpuExecutionProvider); - KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); }; - KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); - - custom_registry->RegisterCustomKernel(create_info); - } - - ORT_RETURN_IF_ERROR(custom_registry->RegisterOpSet(schemas_container.schemas_list, - schemas_container.domain, - schemas_container.baseline_opset_version, - schemas_container.opset_version)); - } + std::shared_ptr custom_registry; + ORT_RETURN_IF_ERROR(CreateCustomRegistry(op_domains, custom_registry)); RegisterCustomRegistry(custom_registry); return Status::OK(); } @@ -1054,4 +940,5 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/test/custom_op_shared_lib/test_custom_op.cc b/onnxruntime/test/custom_op_shared_lib/test_custom_op.cc deleted file mode 100644 index 298c70fa2e..0000000000 --- a/onnxruntime/test/custom_op_shared_lib/test_custom_op.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// example custom op - -#include "core/framework/custom_ops_author.h" -#include "core/session/onnxruntime_c_api.h" - -using namespace onnxruntime; -using namespace onnxruntime::common; -using namespace ONNX_NAMESPACE; - -class FooKernel : public OpKernel { - public: - FooKernel(const OpKernelInfo& info) : OpKernel(info) { - } - - Status Compute(OpKernelContext* ctx) const override { - const Tensor* X = ctx->Input(0); - const Tensor* W = ctx->Input(1); - auto* X_data = X->template Data(); - auto* W_data = W->template Data(); - Tensor* Y = ctx->Output(0, X->Shape()); - auto* Y_data = Y->template MutableData(); - - for (int64_t i = 0; i < X->Shape().Size(); i++) { - Y_data[i] = X_data[i] + W_data[i]; - } - - return Status::OK(); - } -}; - -ORT_EXPORT KernelsContainer* GetAllKernels() { - KernelsContainer* kc = new KernelsContainer; - - KernelDefBuilder def_builder; - def_builder.SetName("Foo") - .SetDomain(onnxruntime::kOnnxDomain) - .SinceVersion(7) - .Provider(onnxruntime::kCpuExecutionProvider) - .TypeConstraint("T", DataTypeImpl::GetTensorType()); - KernelCreateFn kernel_create_fn = [](const OpKernelInfo& info) -> OpKernel* { return new FooKernel(info); }; - KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); - kc->kernels_list.push_back(std::move(create_info)); - return kc; -} - -ORT_EXPORT SchemasContainer* GetAllSchemas() { - SchemasContainer* sc = new SchemasContainer; - sc->domain = onnxruntime::kOnnxDomain; - sc->baseline_opset_version = 5; - sc->opset_version = 7; - - ONNX_NAMESPACE::OpSchema schema("Foo", "unknown", 0); - schema.Input(0, - "A", - "First operand, should share the type with the second operand.", - "T"); - schema.Input( - 1, - "B", - "Second operand. With broadcasting can be of smaller size than A. " - "If broadcasting is disabled it should be of the same size.", - "T"); - schema.Output(0, "C", "Result, has same dimensions and type as A", "T"); - schema.TypeConstraint( - "T", - OpSchema::numeric_types_for_math_reduction(), - "Constrain input and output types to high-precision numeric tensors."); - schema.SinceVersion(7); - - sc->schemas_list.push_back(schema); - return sc; -} - -ORT_EXPORT void FreeKernelsContainer(KernelsContainer* kc) { - delete kc; -} - -ORT_EXPORT void FreeSchemasContainer(SchemasContainer* sc) { - delete sc; -} diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 295c4abaaf..55ed015164 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -13,24 +13,31 @@ #include "onnx_protobuf.h" using namespace onnxruntime; +struct Input { + const char* name; + std::vector dims; + std::vector values; +}; + void RunSession(OrtAllocator* env, OrtSession* session_object, - const std::vector& dims_x, - const std::vector& values_x, + const std::vector& inputs, + const char* output_name, const std::vector& dims_y, const std::vector& values_y, OrtValue* output_tensor) { - std::unique_ptr value_x(nullptr, OrtReleaseValue); - std::vector inputs(1); - inputs[0] = OrtCreateTensorAsOrtValue(env, dims_x, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - value_x.reset(inputs[0]); - void* raw_data; - ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[0], &raw_data)); - memcpy(raw_data, values_x.data(), values_x.size() * sizeof(values_x[0])); - std::vector input_names{"X"}; - const char* output_names[] = {"Y"}; + std::vector ort_inputs; + std::vector> ort_inputs_cleanup; + std::vector input_names; + for (int i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back(OrtCreateTensorWithDataAsOrtValue(env->Info(env), (void*)inputs[i].values.data(), inputs[i].values.size() * sizeof(inputs[i].values[0]), inputs[i].dims, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + ort_inputs_cleanup.emplace_back(ort_inputs.back(), OrtReleaseValue); + } + + // const char* output_names[] = {"Y"}; bool is_output_allocated_by_ort = output_tensor == nullptr; OrtValue* old_output_ptr = output_tensor; - ORT_THROW_ON_ERROR(OrtRun(session_object, NULL, input_names.data(), inputs.data(), inputs.size(), output_names, 1, &output_tensor)); + ORT_THROW_ON_ERROR(OrtRun(session_object, NULL, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1, &output_tensor)); ASSERT_NE(output_tensor, nullptr); if (!is_output_allocated_by_ort) ASSERT_EQ(output_tensor, old_output_ptr); @@ -59,8 +66,8 @@ void RunSession(OrtAllocator* env, OrtSession* session_object, template void TestInference(OrtEnv* env, T model_uri, - const std::vector& dims_x, - const std::vector& values_x, + const std::vector& inputs, + const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y, int provider_type, OrtCustomOpDomain* custom_op_domain_ptr) { @@ -101,8 +108,8 @@ void TestInference(OrtEnv* env, T model_uri, //without preallocated output tensor RunSession(default_allocator.get(), inference_session.get(), - dims_x, - values_x, + inputs, + output_name, expected_dims_y, expected_values_y, nullptr); @@ -123,8 +130,8 @@ void TestInference(OrtEnv* env, T model_uri, for (int i = 0; i != 2; ++i) RunSession(default_allocator.get(), inference_session.get(), - dims_x, - values_x, + inputs, + output_name, expected_dims_y, expected_values_y, value_y.get()); @@ -141,14 +148,17 @@ class CApiTestWithProvider : public CApiTest, TEST_P(CApiTestWithProvider, simple) { // simple inference test // prepare inputs - std::vector dims_x = {3, 2}; - std::vector values_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector inputs(1); + Input& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; // prepare expected inputs and outputs std::vector expected_dims_y = {3, 2}; std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; - TestInference(env, MODEL_URI, dims_x, values_x, expected_dims_y, expected_values_y, GetParam(), nullptr); + TestInference(env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr); } INSTANTIATE_TEST_CASE_P(CApiTestWithProviders, @@ -156,10 +166,9 @@ INSTANTIATE_TEST_CASE_P(CApiTestWithProviders, ::testing::Values(0, 1, 2, 3, 4)); struct OrtTensorDimensions : std::vector { - OrtTensorDimensions(const OrtCustomOpApi& ort, OrtValue* value) { - OrtTensorTypeAndShapeInfo* info; - ORT_THROW_ON_ERROR(ort.GetTensorShapeAndType(value, &info)); - auto dimensionCount = ort.GetNumOfDimensions(info); + OrtTensorDimensions(onnxruntime::CustomOpApi ort, OrtValue* value) { + OrtTensorTypeAndShapeInfo* info = ort.GetTensorShapeAndType(value); + auto dimensionCount = ort.GetDimensionCount(info); resize(dimensionCount); ort.GetDimensions(info, data(), dimensionCount); ort.ReleaseTensorTypeAndShapeInfo(info); @@ -178,57 +187,60 @@ template constexpr size_t countof(T (&)[N]) { return N; } struct MyCustomKernel { - MyCustomKernel(const OrtCustomOpApi& ort, const OrtKernelInfo& /*info*/) : ort_(ort) { + MyCustomKernel(onnxruntime::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) { } - void GetOutputShape(OrtValue** inputs, size_t /*input_count*/, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) { - OrtTensorDimensions dimensions(ort_, inputs[0]); - ORT_THROW_ON_ERROR(ort_.SetDims(info, dimensions.data(), dimensions.size())); + void GetOutputShape(OrtKernelContext* context, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) { + OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); + OrtTensorDimensions dimensions(ort_, input_X); + ort_.SetDimensions(info, dimensions.data(), dimensions.size()); } - void Compute(OrtValue** inputs, size_t /*input_count*/, OrtValue** outputs, size_t /*output_count*/) { - const float* X; - const float* Y; - ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[0], reinterpret_cast(const_cast(&X)))); - ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[1], reinterpret_cast(const_cast(&Y)))); + void Compute(OrtKernelContext* context) { + // Setup inputs + OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); + OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1); + float* X = ort_.GetTensorMutableData(input_X); + float* Y = ort_.GetTensorMutableData(input_Y); - float* out; - ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(outputs[0], reinterpret_cast(&out))); + // Setup output + OrtTensorDimensions dimensions(ort_, input_X); + OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size()); + float* out = ort_.GetTensorMutableData(output); - int64_t size = OrtTensorDimensions(ort_, inputs[0]).ElementCount(); + OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorShapeAndType(output); + int64_t size = ort_.GetTensorShapeElementCount(output_info); + ort_.ReleaseTensorTypeAndShapeInfo(output_info); + + // Do computation for (int64_t i = 0; i < size; i++) { out[i] = X[i] + Y[i]; } } private: - const OrtCustomOpApi& ort_; + onnxruntime::CustomOpApi ort_; }; -struct MyCustomOp : OrtCustomOp { - MyCustomOp() { - OrtCustomOp::version = ORT_API_VERSION; - OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, const OrtCustomOpApi* api, const OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*api, *info); }; - OrtCustomOp::GetName = [](OrtCustomOp* /*this_*/) { return "Foo"; }; +struct MyCustomOp : onnxruntime::CustomOpBase { + void* CreateKernel(onnxruntime::CustomOpApi api, const OrtKernelInfo* info) { return new MyCustomKernel(api, info); }; + const char* GetName() const { return "Foo"; }; - static const ONNXTensorElementDataType c_inputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}; - OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_inputTypes); }; - OrtCustomOp::GetInputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_inputTypes[index]; }; + size_t GetInputTypeCount() const { return 2; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; - static const ONNXTensorElementDataType c_outputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}; - OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_outputTypes); }; - OrtCustomOp::GetOutputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_outputTypes[index]; }; - - OrtCustomOp::KernelGetOutputShape = [](void* op_kernel, OrtValue** inputs, size_t input_count, size_t output_index, OrtTensorTypeAndShapeInfo* output) { static_cast(op_kernel)->GetOutputShape(inputs, input_count, output_index, output); }; - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtValue** inputs, size_t input_count, OrtValue** outputs, size_t output_count) { static_cast(op_kernel)->Compute(inputs, input_count, outputs, output_count); }; - OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast(op_kernel); }; - } + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; }; TEST_F(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; - std::vector dims_x = {3, 2}; - std::vector values_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + std::vector inputs(1); + Input& input = inputs[0]; + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; // prepare expected inputs and outputs std::vector expected_dims_y = {3, 2}; @@ -238,7 +250,7 @@ TEST_F(CApiTest, custom_op_handler) { OrtCustomOpDomain* custom_op_domain = OrtCreateCustomOpDomain(""); ORT_THROW_ON_ERROR(OrtCustomOpDomain_Add(custom_op_domain, &custom_op)); - TestInference(env, CUSTOM_OP_MODEL_URI, dims_x, values_x, expected_dims_y, expected_values_y, 0, custom_op_domain); + TestInference(env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain); OrtReleaseCustomOpDomain(custom_op_domain); } @@ -251,6 +263,7 @@ TEST_F(CApiTest, create_session_without_session_option) { OrtReleaseSession(ret); } #endif + TEST_F(CApiTest, create_tensor) { const char* s[] = {"abc", "kmp"}; int64_t expected_len = 2;