Ryanunderhill/ocr custom op (#744)

* Adding a custom op interface to the C API to remove shared library dependency.
* Remove old custom op test
* Rework how custom ops handle inputs/outputs to enable custom op output shape calculation in the compute method
* Add a nicer C++ API for custom ops and switch the tests to use it.
This commit is contained in:
Ryan Hill 2019-04-05 18:53:20 -07:00 committed by GitHub
parent 58ef1306d4
commit fda1d0dce9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 326 additions and 314 deletions

View file

@ -474,22 +474,12 @@ set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")
# shared lib
if (onnxruntime_BUILD_SHARED_LIB)
if (UNIX)
# test custom op shared lib
file(GLOB onnxruntime_custom_op_shared_lib_test_srcs "${ONNXRUNTIME_ROOT}/test/custom_op_shared_lib/test_custom_op.cc")
add_library(onnxruntime_custom_op_shared_lib_test SHARED ${onnxruntime_custom_op_shared_lib_test_srcs})
onnxruntime_add_include_to_target(onnxruntime_custom_op_shared_lib_test gsl)
add_dependencies(onnxruntime_custom_op_shared_lib_test onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_custom_op_shared_lib_test PUBLIC "${PROJECT_SOURCE_DIR}/include")
target_link_libraries(onnxruntime_custom_op_shared_lib_test PRIVATE onnxruntime onnx onnx_proto protobuf::libprotobuf)
set_target_properties(onnxruntime_custom_op_shared_lib_test PROPERTIES FOLDER "ONNXRuntimeSharedLibTest")
endif()
add_library(onnxruntime_mocked_allocator ${ONNXRUNTIME_ROOT}/test/util/test_allocator.cc)
target_include_directories(onnxruntime_mocked_allocator PUBLIC ${ONNXRUNTIME_ROOT}/test/util/include)
set_target_properties(onnxruntime_mocked_allocator PROPERTIES FOLDER "ONNXRuntimeTest")
#################################################################
# test inference using shared lib + custom op
# test inference using shared lib
set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib")
set (onnxruntime_shared_lib_test_SRC
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h

View file

@ -29,15 +29,15 @@ int main(int argc, char* argv[]) {
CHECK_STATUS(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env));
// initialize session options if needed
OrtSessionOptions* session_option = OrtCreateSessionOptions();
OrtSetSessionThreadPoolSize(session_option, 1);
OrtSessionOptions* session_options = OrtCreateSessionOptions();
OrtSetSessionThreadPoolSize(session_options, 1);
// Sets graph optimization level
// Available levels are
// Available levels are
// 0 -> To disable all optimizations
// 1 -> To enable basic optimizations (Such as redundant node removals)
// 2 -> To enable all optimizations (Includes level 1 + more complex optimizations like node fusions)
OrtSetSessionGraphOptimizationLevel(session_option, 1);
OrtSetSessionGraphOptimizationLevel(session_options, 1);
//*************************************************************************
// create session and load model into memory
@ -45,7 +45,7 @@ int main(int argc, char* argv[]) {
// URL = https://github.com/onnx/models/tree/master/squeezenet
OrtSession* session;
const wchar_t* model_path = L"squeezenet.onnx";
CHECK_STATUS(OrtCreateSession(env, model_path, session_option, &session));
CHECK_STATUS(OrtCreateSession(env, model_path, session_options, &session));
//*************************************************************************
// print model input layer (node names, types, shape etc.)
@ -149,6 +149,7 @@ int main(int argc, char* argv[]) {
OrtReleaseValue(output_tensor);
OrtReleaseValue(input_tensor);
OrtReleaseSession(session);
OrtReleaseSessionOptions(session_options);
OrtReleaseEnv(env);
printf("Done!\n");
return 0;

View file

@ -1,28 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
/**
this header should include all the headers that are required to build a custom op so that
custom op developers don't have to worry about which headers to include, etc.
*/
#include "core/framework/op_kernel.h"
struct KernelsContainer {
std::vector<::onnxruntime::KernelCreateInfo> kernels_list;
};
struct SchemasContainer {
std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;
std::string domain;
int baseline_opset_version;
int opset_version;
};
extern "C" {
KernelsContainer* GetAllKernels();
SchemasContainer* GetAllSchemas();
void FreeKernelsContainer(KernelsContainer*);
void FreeSchemasContainer(SchemasContainer*);
}

View file

@ -534,26 +534,28 @@ ORT_API_STATUS(OrtCreateValue, OrtValue** const in, int num_values, enum ONNXTyp
*/
struct OrtKernelInfo;
typedef struct OrtKernelInfo OrtKernelInfo;
/*
* These allow reading node attributes during kernel creation
*/
ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
struct OrtKernelContext;
typedef struct OrtKernelContext OrtKernelContext;
struct OrtCustomOpApi {
/*
* These allow reading node attributes during kernel creation
*/
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
OrtStatus*(ORT_API_CALL* GetTensorShapeAndType)(_In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
size_t(ORT_API_CALL* GetNumOfDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info);
int64_t(ORT_API_CALL* GetTensorShapeElementCount)(_In_ const OrtTensorTypeAndShapeInfo* info);
size_t(ORT_API_CALL* GetDimensionCount)(_In_ const OrtTensorTypeAndShapeInfo* info);
void(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
OrtStatus*(ORT_API_CALL* SetDims)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Out_ void** out);
OrtStatus*(ORT_API_CALL* SetDimensions)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, void** data);
void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(OrtTensorTypeAndShapeInfo* input);
OrtValue*(ORT_API_CALL* KernelContext_GetInput)(OrtKernelContext* context, _In_ size_t index);
OrtValue*(ORT_API_CALL* KernelContext_GetOutput)(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
};
typedef struct OrtCustomOpApi OrtCustomOpApi;
@ -565,7 +567,7 @@ struct OrtCustomOp {
uint32_t version; // Initialize to ORT_API_VERSION
// This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info, _Out_ void** op_kernel);
void*(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info);
// Returns the name of the op
const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op);
@ -577,8 +579,8 @@ struct OrtCustomOp {
size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ struct OrtCustomOp* op);
// Op kernel callbacks
void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output);
void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ OrtValue** outputs, _In_ size_t output_count);
void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtKernelContext* context, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output);
void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel);
};
typedef struct OrtCustomOp OrtCustomOp;

View file

@ -147,6 +147,90 @@ inline std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info
OrtGetDimensions(info, ret.data(), ret.size());
return ret;
}
struct CustomOpApi {
CustomOpApi(const OrtCustomOpApi& api) : api_(api) {}
template <typename T>
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
OrtTensorTypeAndShapeInfo* GetTensorShapeAndType(_In_ const OrtValue* value) {
OrtTensorTypeAndShapeInfo* out;
ORT_THROW_ON_ERROR(api_.GetTensorShapeAndType(value, &out));
return out;
}
int64_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
return api_.GetTensorShapeElementCount(info);
}
size_t GetDimensionCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
return api_.GetDimensionCount(info);
}
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
api_.GetDimensions(info, dim_values, dim_values_length);
}
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
api_.SetDimensions(info, dim_values, dim_count);
}
template <typename T>
T* GetTensorMutableData(_Inout_ OrtValue* value) {
T* data;
ORT_THROW_ON_ERROR(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
return data;
}
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
api_.ReleaseTensorTypeAndShapeInfo(input);
}
OrtValue* KernelContext_GetInput(OrtKernelContext* context, _In_ size_t index) {
return api_.KernelContext_GetInput(context, index);
}
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
return api_.KernelContext_GetOutput(context, index, dim_values, dim_count);
}
private:
const OrtCustomOpApi& api_;
};
template <>
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
float out;
ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_float(info, name, &out));
return out;
}
template <>
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
int64_t out;
ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_int64(info, name, &out));
return out;
}
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtCustomOpApi* api, const OrtKernelInfo* info) { return static_cast<TOp*>(this_)->CreateKernel(*api, info); };
OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetName(); };
OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetInputType(index); };
OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetOutputTypeCount(); };
OrtCustomOp::GetOutputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetOutputType(index); };
OrtCustomOp::KernelGetOutputShape = [](void* op_kernel, OrtKernelContext* context, size_t output_index, OrtTensorTypeAndShapeInfo* output) { static_cast<TKernel*>(op_kernel)->GetOutputShape(context, output_index, output); };
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
}
};
} // namespace onnxruntime
#undef ORT_REDIRECT_SIMPLE_FUNCTION_CALL

View file

@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef _WIN32
#pragma warning(disable : 4267)
#endif
#include "core/session/inference_session.h"
#include "core/framework/customregistry.h"
#include "core/framework/data_types.h"
#include "core/framework/op_kernel_info.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/tensor_type_and_shape.h"
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
if (status.IsOK())
return nullptr;
return onnxruntime::ToOrtStatus(status);
}
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<int64_t>(name, out);
if (status.IsOK())
return nullptr;
return onnxruntime::ToOrtStatus(status);
}
OrtValue* OrtKernelContext_GetInput(OrtKernelContext* context, _In_ size_t index) {
return reinterpret_cast<OrtValue*>(const_cast<onnxruntime::MLValue*>(reinterpret_cast<onnxruntime::OpKernelContextInternal*>(context)->GetInputMLValue(index)));
};
OrtValue* OrtKernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
onnxruntime::TensorShape shape(dim_values, dim_count);
return reinterpret_cast<OrtValue*>(reinterpret_cast<onnxruntime::OpKernelContextInternal*>(context)->OutputMLValue(index, shape));
};
constexpr OrtCustomOpApi g_custom_op_api = {
&OrtKernelInfoGetAttribute_float,
&OrtKernelInfoGetAttribute_int64,
&OrtGetTensorShapeAndType,
&OrtGetTensorShapeElementCount,
&OrtGetNumOfDimensions,
&OrtGetDimensions,
&OrtSetDims,
&OrtGetTensorMutableData,
&OrtReleaseTensorTypeAndShapeInfo,
&OrtKernelContext_GetInput,
&OrtKernelContext_GetOutput,
};
namespace onnxruntime {
struct CustomOpKernel : OpKernel {
CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) {
if (op_.version != 1)
throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
op_kernel_ = op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)));
}
~CustomOpKernel() {
op_.KernelDestroy(op_kernel_);
}
Status Compute(OpKernelContext* ctx) const override {
auto* ictx = static_cast<OpKernelContextInternal*>(ctx);
op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ictx));
return Status::OK();
}
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel);
OrtCustomOp& op_;
void* op_kernel_;
};
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains, std::shared_ptr<CustomRegistry>& output) {
output = std::make_shared<CustomRegistry>();
for (auto& domain : op_domains) {
if (domain->domain_[0])
ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(domain->domain_, 1, 1000);
std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;
for (auto& op : domain->custom_ops_) {
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0);
auto input_count = op->GetInputTypeCount(op);
for (size_t i = 0; i < input_count; i++) {
auto type = op->GetInputType(op, i);
schema.Input(i, "A", "Description",
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
}
auto output_count = op->GetOutputTypeCount(op);
for (size_t i = 0; i < output_count; i++) {
auto type = op->GetOutputType(op, i);
schema.Output(i, "A", "Description",
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
}
schema.SetDomain(domain->domain_);
schema.SinceVersion(1);
schema.AllowUncheckedAttributes();
schemas_list.push_back(schema);
KernelDefBuilder def_builder;
def_builder.SetName(op->GetName(op))
.SetDomain(domain->domain_)
.SinceVersion(1)
.Provider(onnxruntime::kCpuExecutionProvider);
KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); };
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
output->RegisterCustomKernel(create_info);
}
ORT_RETURN_IF_ERROR(output->RegisterOpSet(schemas_list,
domain->domain_,
1 /* baseline opset version */,
1000 /* opset version */));
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace onnxruntime {
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains, std::shared_ptr<CustomRegistry>& output);
}

View file

@ -43,8 +43,8 @@
#include "core/optimizer/insert_cast_transformer.h"
#include "core/optimizer/transformer_memcpy.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/framework/custom_ops_author.h"
#include "core/session/IOBinding.h"
#include "core/session/custom_ops.h"
#include "core/util/protobuf_parsing_utils.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/graph_transformer_utils.h"
@ -55,37 +55,6 @@
using namespace ONNX_NAMESPACE;
constexpr OrtCustomOpApi g_custom_op_api = {
&OrtKernelInfoGetAttribute_float,
&OrtKernelInfoGetAttribute_int64,
&OrtGetTensorShapeAndType,
&OrtGetNumOfDimensions,
&OrtGetDimensions,
&OrtSetDims,
&OrtGetTensorMutableData,
&OrtReleaseTensorTypeAndShapeInfo,
};
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
if (status.IsOK())
return nullptr;
return onnxruntime::ToOrtStatus(status);
}
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<int64_t>(name, out);
if (status.IsOK())
return nullptr;
return onnxruntime::ToOrtStatus(status);
}
namespace onnxruntime {
namespace {
template <typename T>
@ -121,42 +90,6 @@ inline std::basic_string<T> GetCurrentTimeString() {
return std::basic_string<T>(time_str);
}
} // namespace
struct CustomOpKernel : OpKernel {
CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) {
if (op_.version != 1)
throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)), &op_kernel_);
}
~CustomOpKernel() {
op_.KernelDestroy(op_kernel_);
}
Status Compute(OpKernelContext* ctx) const override {
auto* ictx = static_cast<OpKernelContextInternal*>(ctx);
std::vector<OrtValue*> input_tensors;
auto input_count = ictx->InputCount();
for (int i = 0; i < input_count; i++)
input_tensors.emplace_back(const_cast<OrtValue*>(reinterpret_cast<const OrtValue*>(ictx->GetInputMLValue(i))));
std::vector<OrtValue*> output_tensors;
auto output_count = ictx->OutputCount();
for (int i = 0; i < output_count; i++) {
OrtTensorTypeAndShapeInfo info;
op_.KernelGetOutputShape(op_kernel_, input_tensors.data(), input_tensors.size(), i, &info);
output_tensors.emplace_back(reinterpret_cast<OrtValue*>(ictx->OutputMLValue(0, info.shape)));
}
op_.KernelCompute(op_kernel_, input_tensors.data(), input_tensors.size(), output_tensors.data(), output_tensors.size());
return Status::OK();
}
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel);
OrtCustomOp& op_;
void* op_kernel_;
};
InferenceSession::InferenceSession(const SessionOptions& session_options, logging::LoggingManager* logging_manager)
: session_state_{execution_providers_},
@ -236,55 +169,8 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector<std:
}
common::Status InferenceSession::AddCustomOpDomains(const std::vector<OrtCustomOpDomain*>& op_domains) {
auto custom_registry = std::make_shared<CustomRegistry>();
for (auto& domain : op_domains) {
SchemasContainer schemas_container;
schemas_container.domain = domain->domain_;
schemas_container.baseline_opset_version = 1;
schemas_container.opset_version = 1000;
for (auto& op : domain->custom_ops_) {
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0);
auto input_count = op->GetInputTypeCount(op);
for (size_t i = 0; i < input_count; i++) {
auto type = op->GetInputType(op, i);
schema.Input(i, "A", "Description",
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
}
auto output_count = op->GetOutputTypeCount(op);
for (size_t i = 0; i < output_count; i++) {
auto type = op->GetOutputType(op, i);
schema.Output(i, "A", "Description",
DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
}
schema.SinceVersion(1);
schema.AllowUncheckedAttributes();
schemas_container.schemas_list.push_back(schema);
KernelDefBuilder def_builder;
def_builder.SetName(op->GetName(op))
.SetDomain(onnxruntime::kOnnxDomain)
.SinceVersion(1)
.Provider(onnxruntime::kCpuExecutionProvider);
KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); };
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
custom_registry->RegisterCustomKernel(create_info);
}
ORT_RETURN_IF_ERROR(custom_registry->RegisterOpSet(schemas_container.schemas_list,
schemas_container.domain,
schemas_container.baseline_opset_version,
schemas_container.opset_version));
}
std::shared_ptr<CustomRegistry> custom_registry;
ORT_RETURN_IF_ERROR(CreateCustomRegistry(op_domains, custom_registry));
RegisterCustomRegistry(custom_registry);
return Status::OK();
}
@ -1054,4 +940,5 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do
return Status::OK();
}
} // namespace onnxruntime

View file

@ -1,83 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// example custom op
#include "core/framework/custom_ops_author.h"
#include "core/session/onnxruntime_c_api.h"
using namespace onnxruntime;
using namespace onnxruntime::common;
using namespace ONNX_NAMESPACE;
class FooKernel : public OpKernel {
public:
FooKernel(const OpKernelInfo& info) : OpKernel(info) {
}
Status Compute(OpKernelContext* ctx) const override {
const Tensor* X = ctx->Input<Tensor>(0);
const Tensor* W = ctx->Input<Tensor>(1);
auto* X_data = X->template Data<float>();
auto* W_data = W->template Data<float>();
Tensor* Y = ctx->Output(0, X->Shape());
auto* Y_data = Y->template MutableData<float>();
for (int64_t i = 0; i < X->Shape().Size(); i++) {
Y_data[i] = X_data[i] + W_data[i];
}
return Status::OK();
}
};
ORT_EXPORT KernelsContainer* GetAllKernels() {
KernelsContainer* kc = new KernelsContainer;
KernelDefBuilder def_builder;
def_builder.SetName("Foo")
.SetDomain(onnxruntime::kOnnxDomain)
.SinceVersion(7)
.Provider(onnxruntime::kCpuExecutionProvider)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>());
KernelCreateFn kernel_create_fn = [](const OpKernelInfo& info) -> OpKernel* { return new FooKernel(info); };
KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
kc->kernels_list.push_back(std::move(create_info));
return kc;
}
ORT_EXPORT SchemasContainer* GetAllSchemas() {
SchemasContainer* sc = new SchemasContainer;
sc->domain = onnxruntime::kOnnxDomain;
sc->baseline_opset_version = 5;
sc->opset_version = 7;
ONNX_NAMESPACE::OpSchema schema("Foo", "unknown", 0);
schema.Input(0,
"A",
"First operand, should share the type with the second operand.",
"T");
schema.Input(
1,
"B",
"Second operand. With broadcasting can be of smaller size than A. "
"If broadcasting is disabled it should be of the same size.",
"T");
schema.Output(0, "C", "Result, has same dimensions and type as A", "T");
schema.TypeConstraint(
"T",
OpSchema::numeric_types_for_math_reduction(),
"Constrain input and output types to high-precision numeric tensors.");
schema.SinceVersion(7);
sc->schemas_list.push_back(schema);
return sc;
}
ORT_EXPORT void FreeKernelsContainer(KernelsContainer* kc) {
delete kc;
}
ORT_EXPORT void FreeSchemasContainer(SchemasContainer* sc) {
delete sc;
}

View file

@ -13,24 +13,31 @@
#include "onnx_protobuf.h"
using namespace onnxruntime;
struct Input {
const char* name;
std::vector<int64_t> dims;
std::vector<float> values;
};
void RunSession(OrtAllocator* env, OrtSession* session_object,
const std::vector<int64_t>& dims_x,
const std::vector<float>& values_x,
const std::vector<Input>& inputs,
const char* output_name,
const std::vector<int64_t>& dims_y,
const std::vector<float>& values_y,
OrtValue* output_tensor) {
std::unique_ptr<OrtValue, decltype(&OrtReleaseValue)> value_x(nullptr, OrtReleaseValue);
std::vector<OrtValue*> inputs(1);
inputs[0] = OrtCreateTensorAsOrtValue(env, dims_x, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
value_x.reset(inputs[0]);
void* raw_data;
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[0], &raw_data));
memcpy(raw_data, values_x.data(), values_x.size() * sizeof(values_x[0]));
std::vector<const char*> input_names{"X"};
const char* output_names[] = {"Y"};
std::vector<OrtValue*> ort_inputs;
std::vector<std::unique_ptr<OrtValue, decltype(&OrtReleaseValue)>> ort_inputs_cleanup;
std::vector<const char*> input_names;
for (int i = 0; i < inputs.size(); i++) {
input_names.emplace_back(inputs[i].name);
ort_inputs.emplace_back(OrtCreateTensorWithDataAsOrtValue(env->Info(env), (void*)inputs[i].values.data(), inputs[i].values.size() * sizeof(inputs[i].values[0]), inputs[i].dims, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
ort_inputs_cleanup.emplace_back(ort_inputs.back(), OrtReleaseValue);
}
// const char* output_names[] = {"Y"};
bool is_output_allocated_by_ort = output_tensor == nullptr;
OrtValue* old_output_ptr = output_tensor;
ORT_THROW_ON_ERROR(OrtRun(session_object, NULL, input_names.data(), inputs.data(), inputs.size(), output_names, 1, &output_tensor));
ORT_THROW_ON_ERROR(OrtRun(session_object, NULL, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1, &output_tensor));
ASSERT_NE(output_tensor, nullptr);
if (!is_output_allocated_by_ort)
ASSERT_EQ(output_tensor, old_output_ptr);
@ -59,8 +66,8 @@ void RunSession(OrtAllocator* env, OrtSession* session_object,
template <typename T>
void TestInference(OrtEnv* env, T model_uri,
const std::vector<int64_t>& dims_x,
const std::vector<float>& values_x,
const std::vector<Input>& inputs,
const char* output_name,
const std::vector<int64_t>& expected_dims_y,
const std::vector<float>& expected_values_y,
int provider_type, OrtCustomOpDomain* custom_op_domain_ptr) {
@ -101,8 +108,8 @@ void TestInference(OrtEnv* env, T model_uri,
//without preallocated output tensor
RunSession(default_allocator.get(),
inference_session.get(),
dims_x,
values_x,
inputs,
output_name,
expected_dims_y,
expected_values_y,
nullptr);
@ -123,8 +130,8 @@ void TestInference(OrtEnv* env, T model_uri,
for (int i = 0; i != 2; ++i)
RunSession(default_allocator.get(),
inference_session.get(),
dims_x,
values_x,
inputs,
output_name,
expected_dims_y,
expected_values_y,
value_y.get());
@ -141,14 +148,17 @@ class CApiTestWithProvider : public CApiTest,
TEST_P(CApiTestWithProvider, simple) {
// simple inference test
// prepare inputs
std::vector<int64_t> dims_x = {3, 2};
std::vector<float> values_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
std::vector<Input> inputs(1);
Input& input = inputs.back();
input.name = "X";
input.dims = {3, 2};
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_y = {3, 2};
std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
TestInference<PATH_TYPE>(env, MODEL_URI, dims_x, values_x, expected_dims_y, expected_values_y, GetParam(), nullptr);
TestInference<PATH_TYPE>(env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr);
}
INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
@ -156,10 +166,9 @@ INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
::testing::Values(0, 1, 2, 3, 4));
struct OrtTensorDimensions : std::vector<int64_t> {
OrtTensorDimensions(const OrtCustomOpApi& ort, OrtValue* value) {
OrtTensorTypeAndShapeInfo* info;
ORT_THROW_ON_ERROR(ort.GetTensorShapeAndType(value, &info));
auto dimensionCount = ort.GetNumOfDimensions(info);
OrtTensorDimensions(onnxruntime::CustomOpApi ort, OrtValue* value) {
OrtTensorTypeAndShapeInfo* info = ort.GetTensorShapeAndType(value);
auto dimensionCount = ort.GetDimensionCount(info);
resize(dimensionCount);
ort.GetDimensions(info, data(), dimensionCount);
ort.ReleaseTensorTypeAndShapeInfo(info);
@ -178,57 +187,60 @@ template <typename T, size_t N>
constexpr size_t countof(T (&)[N]) { return N; }
struct MyCustomKernel {
MyCustomKernel(const OrtCustomOpApi& ort, const OrtKernelInfo& /*info*/) : ort_(ort) {
MyCustomKernel(onnxruntime::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
}
void GetOutputShape(OrtValue** inputs, size_t /*input_count*/, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) {
OrtTensorDimensions dimensions(ort_, inputs[0]);
ORT_THROW_ON_ERROR(ort_.SetDims(info, dimensions.data(), dimensions.size()));
void GetOutputShape(OrtKernelContext* context, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) {
OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
OrtTensorDimensions dimensions(ort_, input_X);
ort_.SetDimensions(info, dimensions.data(), dimensions.size());
}
void Compute(OrtValue** inputs, size_t /*input_count*/, OrtValue** outputs, size_t /*output_count*/) {
const float* X;
const float* Y;
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[0], reinterpret_cast<void**>(const_cast<float**>(&X))));
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[1], reinterpret_cast<void**>(const_cast<float**>(&Y))));
void Compute(OrtKernelContext* context) {
// Setup inputs
OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
float* X = ort_.GetTensorMutableData<float>(input_X);
float* Y = ort_.GetTensorMutableData<float>(input_Y);
float* out;
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(outputs[0], reinterpret_cast<void**>(&out)));
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
float* out = ort_.GetTensorMutableData<float>(output);
int64_t size = OrtTensorDimensions(ort_, inputs[0]).ElementCount();
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorShapeAndType(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
for (int64_t i = 0; i < size; i++) {
out[i] = X[i] + Y[i];
}
}
private:
const OrtCustomOpApi& ort_;
onnxruntime::CustomOpApi ort_;
};
struct MyCustomOp : OrtCustomOp {
MyCustomOp() {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, const OrtCustomOpApi* api, const OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*api, *info); };
OrtCustomOp::GetName = [](OrtCustomOp* /*this_*/) { return "Foo"; };
struct MyCustomOp : onnxruntime::CustomOpBase<MyCustomOp, MyCustomKernel> {
void* CreateKernel(onnxruntime::CustomOpApi api, const OrtKernelInfo* info) { return new MyCustomKernel(api, info); };
const char* GetName() const { return "Foo"; };
static const ONNXTensorElementDataType c_inputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_inputTypes); };
OrtCustomOp::GetInputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_inputTypes[index]; };
size_t GetInputTypeCount() const { return 2; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
static const ONNXTensorElementDataType c_outputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_outputTypes); };
OrtCustomOp::GetOutputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_outputTypes[index]; };
OrtCustomOp::KernelGetOutputShape = [](void* op_kernel, OrtValue** inputs, size_t input_count, size_t output_index, OrtTensorTypeAndShapeInfo* output) { static_cast<MyCustomKernel*>(op_kernel)->GetOutputShape(inputs, input_count, output_index, output); };
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtValue** inputs, size_t input_count, OrtValue** outputs, size_t output_count) { static_cast<MyCustomKernel*>(op_kernel)->Compute(inputs, input_count, outputs, output_count); };
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<MyCustomKernel*>(op_kernel); };
}
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
};
TEST_F(CApiTest, custom_op_handler) {
std::cout << "Running custom op inference" << std::endl;
std::vector<int64_t> dims_x = {3, 2};
std::vector<float> values_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
std::vector<Input> inputs(1);
Input& input = inputs[0];
input.name = "X";
input.dims = {3, 2};
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_y = {3, 2};
@ -238,7 +250,7 @@ TEST_F(CApiTest, custom_op_handler) {
OrtCustomOpDomain* custom_op_domain = OrtCreateCustomOpDomain("");
ORT_THROW_ON_ERROR(OrtCustomOpDomain_Add(custom_op_domain, &custom_op));
TestInference<PATH_TYPE>(env, CUSTOM_OP_MODEL_URI, dims_x, values_x, expected_dims_y, expected_values_y, 0, custom_op_domain);
TestInference<PATH_TYPE>(env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain);
OrtReleaseCustomOpDomain(custom_op_domain);
}
@ -251,6 +263,7 @@ TEST_F(CApiTest, create_session_without_session_option) {
OrtReleaseSession(ret);
}
#endif
TEST_F(CApiTest, create_tensor) {
const char* s[] = {"abc", "kmp"};
int64_t expected_len = 2;