mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Ryanunderhill/ocr custom op (#744)
* Adding a custom op interface to the C API to remove shared library dependency. * Remove old custom op test * Rework how custom ops handle inputs/outputs to enable custom op output shape calculation in the compute method * Add a nicer C++ API for custom ops and switch the tests to use it.
This commit is contained in:
parent
58ef1306d4
commit
fda1d0dce9
10 changed files with 326 additions and 314 deletions
|
|
@ -474,22 +474,12 @@ set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")
|
|||
|
||||
# shared lib
|
||||
if (onnxruntime_BUILD_SHARED_LIB)
|
||||
if (UNIX)
|
||||
# test custom op shared lib
|
||||
file(GLOB onnxruntime_custom_op_shared_lib_test_srcs "${ONNXRUNTIME_ROOT}/test/custom_op_shared_lib/test_custom_op.cc")
|
||||
add_library(onnxruntime_custom_op_shared_lib_test SHARED ${onnxruntime_custom_op_shared_lib_test_srcs})
|
||||
onnxruntime_add_include_to_target(onnxruntime_custom_op_shared_lib_test gsl)
|
||||
add_dependencies(onnxruntime_custom_op_shared_lib_test onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
target_include_directories(onnxruntime_custom_op_shared_lib_test PUBLIC "${PROJECT_SOURCE_DIR}/include")
|
||||
target_link_libraries(onnxruntime_custom_op_shared_lib_test PRIVATE onnxruntime onnx onnx_proto protobuf::libprotobuf)
|
||||
set_target_properties(onnxruntime_custom_op_shared_lib_test PROPERTIES FOLDER "ONNXRuntimeSharedLibTest")
|
||||
endif()
|
||||
add_library(onnxruntime_mocked_allocator ${ONNXRUNTIME_ROOT}/test/util/test_allocator.cc)
|
||||
target_include_directories(onnxruntime_mocked_allocator PUBLIC ${ONNXRUNTIME_ROOT}/test/util/include)
|
||||
set_target_properties(onnxruntime_mocked_allocator PROPERTIES FOLDER "ONNXRuntimeTest")
|
||||
|
||||
#################################################################
|
||||
# test inference using shared lib + custom op
|
||||
# test inference using shared lib
|
||||
set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib")
|
||||
set (onnxruntime_shared_lib_test_SRC
|
||||
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
|
||||
|
|
|
|||
|
|
@ -29,15 +29,15 @@ int main(int argc, char* argv[]) {
|
|||
CHECK_STATUS(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env));
|
||||
|
||||
// initialize session options if needed
|
||||
OrtSessionOptions* session_option = OrtCreateSessionOptions();
|
||||
OrtSetSessionThreadPoolSize(session_option, 1);
|
||||
OrtSessionOptions* session_options = OrtCreateSessionOptions();
|
||||
OrtSetSessionThreadPoolSize(session_options, 1);
|
||||
|
||||
// Sets graph optimization level
|
||||
// Available levels are
|
||||
// Available levels are
|
||||
// 0 -> To disable all optimizations
|
||||
// 1 -> To enable basic optimizations (Such as redundant node removals)
|
||||
// 2 -> To enable all optimizations (Includes level 1 + more complex optimizations like node fusions)
|
||||
OrtSetSessionGraphOptimizationLevel(session_option, 1);
|
||||
OrtSetSessionGraphOptimizationLevel(session_options, 1);
|
||||
|
||||
//*************************************************************************
|
||||
// create session and load model into memory
|
||||
|
|
@ -45,7 +45,7 @@ int main(int argc, char* argv[]) {
|
|||
// URL = https://github.com/onnx/models/tree/master/squeezenet
|
||||
OrtSession* session;
|
||||
const wchar_t* model_path = L"squeezenet.onnx";
|
||||
CHECK_STATUS(OrtCreateSession(env, model_path, session_option, &session));
|
||||
CHECK_STATUS(OrtCreateSession(env, model_path, session_options, &session));
|
||||
|
||||
//*************************************************************************
|
||||
// print model input layer (node names, types, shape etc.)
|
||||
|
|
@ -149,6 +149,7 @@ int main(int argc, char* argv[]) {
|
|||
OrtReleaseValue(output_tensor);
|
||||
OrtReleaseValue(input_tensor);
|
||||
OrtReleaseSession(session);
|
||||
OrtReleaseSessionOptions(session_options);
|
||||
OrtReleaseEnv(env);
|
||||
printf("Done!\n");
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -1,28 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
this header should include all the headers that are required to build a custom op so that
|
||||
custom op developers don't have to worry about which headers to include, etc.
|
||||
*/
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
struct KernelsContainer {
|
||||
std::vector<::onnxruntime::KernelCreateInfo> kernels_list;
|
||||
};
|
||||
|
||||
struct SchemasContainer {
|
||||
std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;
|
||||
std::string domain;
|
||||
int baseline_opset_version;
|
||||
int opset_version;
|
||||
};
|
||||
|
||||
extern "C" {
|
||||
KernelsContainer* GetAllKernels();
|
||||
SchemasContainer* GetAllSchemas();
|
||||
void FreeKernelsContainer(KernelsContainer*);
|
||||
void FreeSchemasContainer(SchemasContainer*);
|
||||
}
|
||||
|
|
@ -534,26 +534,28 @@ ORT_API_STATUS(OrtCreateValue, OrtValue** const in, int num_values, enum ONNXTyp
|
|||
*/
|
||||
struct OrtKernelInfo;
|
||||
typedef struct OrtKernelInfo OrtKernelInfo;
|
||||
|
||||
/*
|
||||
* These allow reading node attributes during kernel creation
|
||||
*/
|
||||
ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
|
||||
ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
|
||||
struct OrtKernelContext;
|
||||
typedef struct OrtKernelContext OrtKernelContext;
|
||||
|
||||
struct OrtCustomOpApi {
|
||||
/*
|
||||
* These allow reading node attributes during kernel creation
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
|
||||
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
|
||||
|
||||
OrtStatus*(ORT_API_CALL* GetTensorShapeAndType)(_In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
|
||||
|
||||
size_t(ORT_API_CALL* GetNumOfDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info);
|
||||
int64_t(ORT_API_CALL* GetTensorShapeElementCount)(_In_ const OrtTensorTypeAndShapeInfo* info);
|
||||
size_t(ORT_API_CALL* GetDimensionCount)(_In_ const OrtTensorTypeAndShapeInfo* info);
|
||||
void(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
|
||||
OrtStatus*(ORT_API_CALL* SetDims)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
|
||||
|
||||
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Out_ void** out);
|
||||
OrtStatus*(ORT_API_CALL* SetDimensions)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
|
||||
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, void** data);
|
||||
|
||||
void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(OrtTensorTypeAndShapeInfo* input);
|
||||
|
||||
OrtValue*(ORT_API_CALL* KernelContext_GetInput)(OrtKernelContext* context, _In_ size_t index);
|
||||
OrtValue*(ORT_API_CALL* KernelContext_GetOutput)(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
|
||||
};
|
||||
typedef struct OrtCustomOpApi OrtCustomOpApi;
|
||||
|
||||
|
|
@ -565,7 +567,7 @@ struct OrtCustomOp {
|
|||
uint32_t version; // Initialize to ORT_API_VERSION
|
||||
|
||||
// This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
|
||||
void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info, _Out_ void** op_kernel);
|
||||
void*(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info);
|
||||
|
||||
// Returns the name of the op
|
||||
const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op);
|
||||
|
|
@ -577,8 +579,8 @@ struct OrtCustomOp {
|
|||
size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ struct OrtCustomOp* op);
|
||||
|
||||
// Op kernel callbacks
|
||||
void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output);
|
||||
void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ OrtValue** outputs, _In_ size_t output_count);
|
||||
void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtKernelContext* context, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output);
|
||||
void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
|
||||
void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel);
|
||||
};
|
||||
typedef struct OrtCustomOp OrtCustomOp;
|
||||
|
|
|
|||
|
|
@ -147,6 +147,90 @@ inline std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info
|
|||
OrtGetDimensions(info, ret.data(), ret.size());
|
||||
return ret;
|
||||
}
|
||||
|
||||
struct CustomOpApi {
|
||||
CustomOpApi(const OrtCustomOpApi& api) : api_(api) {}
|
||||
|
||||
template <typename T>
|
||||
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* GetTensorShapeAndType(_In_ const OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* out;
|
||||
ORT_THROW_ON_ERROR(api_.GetTensorShapeAndType(value, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
int64_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
|
||||
return api_.GetTensorShapeElementCount(info);
|
||||
}
|
||||
|
||||
size_t GetDimensionCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
|
||||
return api_.GetDimensionCount(info);
|
||||
}
|
||||
|
||||
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
|
||||
api_.GetDimensions(info, dim_values, dim_values_length);
|
||||
}
|
||||
|
||||
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
|
||||
api_.SetDimensions(info, dim_values, dim_count);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* GetTensorMutableData(_Inout_ OrtValue* value) {
|
||||
T* data;
|
||||
ORT_THROW_ON_ERROR(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
||||
return data;
|
||||
}
|
||||
|
||||
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
|
||||
api_.ReleaseTensorTypeAndShapeInfo(input);
|
||||
}
|
||||
|
||||
OrtValue* KernelContext_GetInput(OrtKernelContext* context, _In_ size_t index) {
|
||||
return api_.KernelContext_GetInput(context, index);
|
||||
}
|
||||
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
|
||||
return api_.KernelContext_GetOutput(context, index, dim_values, dim_count);
|
||||
}
|
||||
|
||||
private:
|
||||
const OrtCustomOpApi& api_;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
||||
float out;
|
||||
ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_float(info, name, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
||||
int64_t out;
|
||||
ORT_THROW_ON_ERROR(api_.KernelInfoGetAttribute_int64(info, name, &out));
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename TOp, typename TKernel>
|
||||
struct CustomOpBase : OrtCustomOp {
|
||||
CustomOpBase() {
|
||||
OrtCustomOp::version = ORT_API_VERSION;
|
||||
OrtCustomOp::CreateKernel = [](OrtCustomOp* this_, const OrtCustomOpApi* api, const OrtKernelInfo* info) { return static_cast<TOp*>(this_)->CreateKernel(*api, info); };
|
||||
OrtCustomOp::GetName = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetName(); };
|
||||
|
||||
OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetInputTypeCount(); };
|
||||
OrtCustomOp::GetInputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetInputType(index); };
|
||||
|
||||
OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* this_) { return static_cast<TOp*>(this_)->GetOutputTypeCount(); };
|
||||
OrtCustomOp::GetOutputType = [](OrtCustomOp* this_, size_t index) { return static_cast<TOp*>(this_)->GetOutputType(index); };
|
||||
|
||||
OrtCustomOp::KernelGetOutputShape = [](void* op_kernel, OrtKernelContext* context, size_t output_index, OrtTensorTypeAndShapeInfo* output) { static_cast<TKernel*>(op_kernel)->GetOutputShape(context, output_index, output); };
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#undef ORT_REDIRECT_SIMPLE_FUNCTION_CALL
|
||||
|
|
|
|||
138
onnxruntime/core/session/custom_ops.cc
Normal file
138
onnxruntime/core/session/custom_ops.cc
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef _WIN32
|
||||
#pragma warning(disable : 4267)
|
||||
#endif
|
||||
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/framework/customregistry.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/op_kernel_info.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include "core/framework/tensor_type_and_shape.h"
|
||||
|
||||
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
|
||||
|
||||
// Reads the float attribute `name` from the kernel info into `out`.
// Returns nullptr on success, otherwise the failure converted to an OrtStatus.
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
  // OrtKernelInfo is the opaque C-API view of onnxruntime::OpKernelInfo.
  const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
  auto attr_status = kernel_info->GetAttr<float>(name, out);
  return attr_status.IsOK() ? nullptr : onnxruntime::ToOrtStatus(attr_status);
}
|
||||
|
||||
// Reads the int64 attribute `name` from the kernel info into `out`.
// Returns nullptr on success, otherwise the failure converted to an OrtStatus.
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
  // OrtKernelInfo is the opaque C-API view of onnxruntime::OpKernelInfo.
  const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
  auto attr_status = kernel_info->GetAttr<int64_t>(name, out);
  return attr_status.IsOK() ? nullptr : onnxruntime::ToOrtStatus(attr_status);
}
|
||||
|
||||
// Returns input `index` of the kernel invocation behind `context`.
// OrtKernelContext is the opaque C-API view of OpKernelContextInternal, and
// OrtValue is the opaque view of MLValue.
OrtValue* OrtKernelContext_GetInput(OrtKernelContext* context, _In_ size_t index) {
  auto* internal_context = reinterpret_cast<onnxruntime::OpKernelContextInternal*>(context);
  const onnxruntime::MLValue* input_value = internal_context->GetInputMLValue(index);
  // The C API hands out a non-const OrtValue*, hence the const_cast.
  return reinterpret_cast<OrtValue*>(const_cast<onnxruntime::MLValue*>(input_value));
}
|
||||
|
||||
// Returns output `index` of the kernel invocation behind `context`, requesting
// it with the shape given by `dim_values[0 .. dim_count)`.
OrtValue* OrtKernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) {
  auto* internal_context = reinterpret_cast<onnxruntime::OpKernelContextInternal*>(context);
  onnxruntime::TensorShape requested_shape(dim_values, dim_count);
  return reinterpret_cast<OrtValue*>(internal_context->OutputMLValue(index, requested_shape));
}
|
||||
|
||||
// Function table passed to every custom op's CreateKernel.
// This is positional aggregate initialization: the entries must stay in the
// same order as the members of OrtCustomOpApi.
constexpr OrtCustomOpApi g_custom_op_api = {
    // node attribute access
    OrtKernelInfoGetAttribute_float,
    OrtKernelInfoGetAttribute_int64,

    // tensor type/shape queries
    OrtGetTensorShapeAndType,

    OrtGetTensorShapeElementCount,
    OrtGetNumOfDimensions,
    OrtGetDimensions,
    OrtSetDims,

    // tensor data access
    OrtGetTensorMutableData,

    OrtReleaseTensorTypeAndShapeInfo,

    // kernel context input/output access
    OrtKernelContext_GetInput,
    OrtKernelContext_GetOutput,
};
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct CustomOpKernel : OpKernel {
|
||||
CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) {
|
||||
if (op_.version != 1)
|
||||
throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
|
||||
op_kernel_ = op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)));
|
||||
}
|
||||
|
||||
~CustomOpKernel() {
|
||||
op_.KernelDestroy(op_kernel_);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override {
|
||||
auto* ictx = static_cast<OpKernelContextInternal*>(ctx);
|
||||
op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ictx));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel);
|
||||
|
||||
OrtCustomOp& op_;
|
||||
void* op_kernel_;
|
||||
};
|
||||
|
||||
// Builds a CustomRegistry holding one schema and one CPU kernel registration
// per custom op in `op_domains`.
//
// `output` is replaced with a fresh registry before anything is registered, so
// on a failed return it may be partially populated. The kernel factories stored
// in the registry keep the OrtCustomOp pointers, so the ops must stay alive for
// as long as the registry is used.
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains, std::shared_ptr<CustomRegistry>& output) {
  output = std::make_shared<CustomRegistry>();

  for (auto& domain : op_domains) {
    // Only a non-empty domain name is added to the global domain/version map.
    if (domain->domain_[0])
      ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion(domain->domain_, 1, 1000);

    std::vector<ONNX_NAMESPACE::OpSchema> schemas_list;

    for (auto& op : domain->custom_ops_) {
      ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0);

      auto input_count = op->GetInputTypeCount(op);
      for (size_t i = 0; i < input_count; i++) {
        auto type = op->GetInputType(op, i);

        schema.Input(i, "A", "Description",
                     DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
      }

      auto output_count = op->GetOutputTypeCount(op);
      for (size_t i = 0; i < output_count; i++) {
        auto type = op->GetOutputType(op, i);

        schema.Output(i, "A", "Description",
                      DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type)));
      }

      schema.SetDomain(domain->domain_);
      schema.SinceVersion(1);
      schema.AllowUncheckedAttributes();
      schemas_list.push_back(schema);

      KernelDefBuilder def_builder;
      def_builder.SetName(op->GetName(op))
          .SetDomain(domain->domain_)
          .SinceVersion(1)
          .Provider(onnxruntime::kCpuExecutionProvider);

      // The factory is stored in the registry and invoked long after this loop
      // has finished. Capture the op pointer by value so the closure does not
      // depend on the storage of domain->custom_ops_ (a by-reference capture of
      // the loop variable would dangle if that container is ever reallocated).
      OrtCustomOp* custom_op = op;
      KernelCreateFn kernel_create_fn = [custom_op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *custom_op); };
      KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);

      // NOTE(review): any status returned by RegisterCustomKernel is discarded
      // here, as before - confirm whether it should be propagated.
      output->RegisterCustomKernel(create_info);
    }

    ORT_RETURN_IF_ERROR(output->RegisterOpSet(schemas_list,
                                              domain->domain_,
                                              1 /* baseline opset version */,
                                              1000 /* opset version */));
  }
  return Status::OK();
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
8
onnxruntime/core/session/custom_ops.h
Normal file
8
onnxruntime/core/session/custom_ops.h
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
common::Status CreateCustomRegistry(const std::vector<OrtCustomOpDomain*>& op_domains, std::shared_ptr<CustomRegistry>& output);
|
||||
|
||||
}
|
||||
|
|
@ -43,8 +43,8 @@
|
|||
#include "core/optimizer/insert_cast_transformer.h"
|
||||
#include "core/optimizer/transformer_memcpy.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "core/framework/custom_ops_author.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "core/session/custom_ops.h"
|
||||
#include "core/util/protobuf_parsing_utils.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/graph_transformer_utils.h"
|
||||
|
|
@ -55,37 +55,6 @@
|
|||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
// Function table passed to custom ops.
// This is positional aggregate initialization: the entries must stay in the
// same order as the members of OrtCustomOpApi.
constexpr OrtCustomOpApi g_custom_op_api = {
    // node attribute access
    OrtKernelInfoGetAttribute_float,
    OrtKernelInfoGetAttribute_int64,

    // tensor type/shape queries
    OrtGetTensorShapeAndType,

    OrtGetNumOfDimensions,
    OrtGetDimensions,
    OrtSetDims,

    // tensor data access
    OrtGetTensorMutableData,

    OrtReleaseTensorTypeAndShapeInfo,
};
|
||||
|
||||
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
|
||||
|
||||
// Reads the float attribute `name` from the kernel info into `out`.
// Returns nullptr on success, otherwise the failure converted to an OrtStatus.
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
  // OrtKernelInfo is the opaque C-API view of onnxruntime::OpKernelInfo.
  const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
  auto attr_status = kernel_info->GetAttr<float>(name, out);
  return attr_status.IsOK() ? nullptr : onnxruntime::ToOrtStatus(attr_status);
}
|
||||
|
||||
// Reads the int64 attribute `name` from the kernel info into `out`.
// Returns nullptr on success, otherwise the failure converted to an OrtStatus.
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
  // OrtKernelInfo is the opaque C-API view of onnxruntime::OpKernelInfo.
  const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
  auto attr_status = kernel_info->GetAttr<int64_t>(name, out);
  return attr_status.IsOK() ? nullptr : onnxruntime::ToOrtStatus(attr_status);
}
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace {
|
||||
template <typename T>
|
||||
|
|
@ -121,42 +90,6 @@ inline std::basic_string<T> GetCurrentTimeString() {
|
|||
return std::basic_string<T>(time_str);
|
||||
}
|
||||
} // namespace
|
||||
struct CustomOpKernel : OpKernel {
|
||||
CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) {
|
||||
if (op_.version != 1)
|
||||
throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
|
||||
op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)), &op_kernel_);
|
||||
}
|
||||
|
||||
~CustomOpKernel() {
|
||||
op_.KernelDestroy(op_kernel_);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override {
|
||||
auto* ictx = static_cast<OpKernelContextInternal*>(ctx);
|
||||
std::vector<OrtValue*> input_tensors;
|
||||
auto input_count = ictx->InputCount();
|
||||
for (int i = 0; i < input_count; i++)
|
||||
input_tensors.emplace_back(const_cast<OrtValue*>(reinterpret_cast<const OrtValue*>(ictx->GetInputMLValue(i))));
|
||||
|
||||
std::vector<OrtValue*> output_tensors;
|
||||
auto output_count = ictx->OutputCount();
|
||||
for (int i = 0; i < output_count; i++) {
|
||||
OrtTensorTypeAndShapeInfo info;
|
||||
op_.KernelGetOutputShape(op_kernel_, input_tensors.data(), input_tensors.size(), i, &info);
|
||||
output_tensors.emplace_back(reinterpret_cast<OrtValue*>(ictx->OutputMLValue(0, info.shape)));
|
||||
}
|
||||
|
||||
op_.KernelCompute(op_kernel_, input_tensors.data(), input_tensors.size(), output_tensors.data(), output_tensors.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel);
|
||||
|
||||
OrtCustomOp& op_;
|
||||
void* op_kernel_;
|
||||
};
|
||||
|
||||
InferenceSession::InferenceSession(const SessionOptions& session_options, logging::LoggingManager* logging_manager)
|
||||
: session_state_{execution_providers_},
|
||||
|
|
@ -236,55 +169,8 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector<std:
|
|||
}
|
||||
|
||||
// Registers the custom ops of `op_domains` with this session.
// The registry (schemas plus kernel factories) is built by
// CreateCustomRegistry; a failure there is returned to the caller and nothing
// is registered.
common::Status InferenceSession::AddCustomOpDomains(const std::vector<OrtCustomOpDomain*>& op_domains) {
  std::shared_ptr<CustomRegistry> custom_registry;
  ORT_RETURN_IF_ERROR(CreateCustomRegistry(op_domains, custom_registry));
  RegisterCustomRegistry(custom_registry);
  return Status::OK();
}
|
||||
|
|
@ -1054,4 +940,5 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,83 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// example custom op
|
||||
|
||||
#include "core/framework/custom_ops_author.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
using namespace onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
class FooKernel : public OpKernel {
|
||||
public:
|
||||
FooKernel(const OpKernelInfo& info) : OpKernel(info) {
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override {
|
||||
const Tensor* X = ctx->Input<Tensor>(0);
|
||||
const Tensor* W = ctx->Input<Tensor>(1);
|
||||
auto* X_data = X->template Data<float>();
|
||||
auto* W_data = W->template Data<float>();
|
||||
Tensor* Y = ctx->Output(0, X->Shape());
|
||||
auto* Y_data = Y->template MutableData<float>();
|
||||
|
||||
for (int64_t i = 0; i < X->Shape().Size(); i++) {
|
||||
Y_data[i] = X_data[i] + W_data[i];
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
// Exported entry point: returns a heap-allocated container with the single
// "Foo" CPU kernel registration (float tensors, since version 7).
// The result is meant to be handed back to FreeKernelsContainer().
ORT_EXPORT KernelsContainer* GetAllKernels() {
  KernelsContainer* container = new KernelsContainer;

  // Describe the kernel.
  KernelDefBuilder foo_def;
  foo_def.SetName("Foo")
      .SetDomain(onnxruntime::kOnnxDomain)
      .SinceVersion(7)
      .Provider(onnxruntime::kCpuExecutionProvider)
      .TypeConstraint("T", DataTypeImpl::GetTensorType<float>());

  // Factory producing a fresh FooKernel per node.
  KernelCreateFn make_foo = [](const OpKernelInfo& info) -> OpKernel* {
    return new FooKernel(info);
  };

  KernelCreateInfo foo_registration(foo_def.Build(), make_foo);
  container->kernels_list.push_back(std::move(foo_registration));
  return container;
}
|
||||
|
||||
// Exported entry point: returns a heap-allocated container with the schema of
// the "Foo" op (two inputs and one output sharing type "T", since version 7),
// registered in the ONNX domain for opset versions 5 through 7.
// The result is meant to be handed back to FreeSchemasContainer().
ORT_EXPORT SchemasContainer* GetAllSchemas() {
  SchemasContainer* container = new SchemasContainer;
  container->domain = onnxruntime::kOnnxDomain;
  container->baseline_opset_version = 5;
  container->opset_version = 7;

  ONNX_NAMESPACE::OpSchema foo_schema("Foo", "unknown", 0);

  // Inputs
  foo_schema.Input(0, "A", "First operand, should share the type with the second operand.", "T");
  foo_schema.Input(1, "B",
                   "Second operand. With broadcasting can be of smaller size than A. "
                   "If broadcasting is disabled it should be of the same size.",
                   "T");

  // Output
  foo_schema.Output(0, "C", "Result, has same dimensions and type as A", "T");

  foo_schema.TypeConstraint("T",
                            OpSchema::numeric_types_for_math_reduction(),
                            "Constrain input and output types to high-precision numeric tensors.");
  foo_schema.SinceVersion(7);

  container->schemas_list.push_back(foo_schema);
  return container;
}
|
||||
|
||||
// Exported entry point: deletes a container created by GetAllKernels().
ORT_EXPORT void FreeKernelsContainer(KernelsContainer* kc) {
  delete kc;
}
|
||||
|
||||
// Exported entry point: deletes a container created by GetAllSchemas().
ORT_EXPORT void FreeSchemasContainer(SchemasContainer* sc) {
  delete sc;
}
|
||||
|
|
@ -13,24 +13,31 @@
|
|||
#include "onnx_protobuf.h"
|
||||
using namespace onnxruntime;
|
||||
|
||||
// One named float input tensor for a test run.
struct Input {
  // Graph input name; not owned. Null until assigned.
  const char* name = nullptr;
  // Tensor shape.
  std::vector<int64_t> dims;
  // Tensor contents as a flat float array.
  std::vector<float> values;
};
|
||||
|
||||
void RunSession(OrtAllocator* env, OrtSession* session_object,
|
||||
const std::vector<int64_t>& dims_x,
|
||||
const std::vector<float>& values_x,
|
||||
const std::vector<Input>& inputs,
|
||||
const char* output_name,
|
||||
const std::vector<int64_t>& dims_y,
|
||||
const std::vector<float>& values_y,
|
||||
OrtValue* output_tensor) {
|
||||
std::unique_ptr<OrtValue, decltype(&OrtReleaseValue)> value_x(nullptr, OrtReleaseValue);
|
||||
std::vector<OrtValue*> inputs(1);
|
||||
inputs[0] = OrtCreateTensorAsOrtValue(env, dims_x, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
value_x.reset(inputs[0]);
|
||||
void* raw_data;
|
||||
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[0], &raw_data));
|
||||
memcpy(raw_data, values_x.data(), values_x.size() * sizeof(values_x[0]));
|
||||
std::vector<const char*> input_names{"X"};
|
||||
const char* output_names[] = {"Y"};
|
||||
std::vector<OrtValue*> ort_inputs;
|
||||
std::vector<std::unique_ptr<OrtValue, decltype(&OrtReleaseValue)>> ort_inputs_cleanup;
|
||||
std::vector<const char*> input_names;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
input_names.emplace_back(inputs[i].name);
|
||||
ort_inputs.emplace_back(OrtCreateTensorWithDataAsOrtValue(env->Info(env), (void*)inputs[i].values.data(), inputs[i].values.size() * sizeof(inputs[i].values[0]), inputs[i].dims, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
|
||||
ort_inputs_cleanup.emplace_back(ort_inputs.back(), OrtReleaseValue);
|
||||
}
|
||||
|
||||
// const char* output_names[] = {"Y"};
|
||||
bool is_output_allocated_by_ort = output_tensor == nullptr;
|
||||
OrtValue* old_output_ptr = output_tensor;
|
||||
ORT_THROW_ON_ERROR(OrtRun(session_object, NULL, input_names.data(), inputs.data(), inputs.size(), output_names, 1, &output_tensor));
|
||||
ORT_THROW_ON_ERROR(OrtRun(session_object, NULL, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1, &output_tensor));
|
||||
ASSERT_NE(output_tensor, nullptr);
|
||||
if (!is_output_allocated_by_ort)
|
||||
ASSERT_EQ(output_tensor, old_output_ptr);
|
||||
|
|
@ -59,8 +66,8 @@ void RunSession(OrtAllocator* env, OrtSession* session_object,
|
|||
|
||||
template <typename T>
|
||||
void TestInference(OrtEnv* env, T model_uri,
|
||||
const std::vector<int64_t>& dims_x,
|
||||
const std::vector<float>& values_x,
|
||||
const std::vector<Input>& inputs,
|
||||
const char* output_name,
|
||||
const std::vector<int64_t>& expected_dims_y,
|
||||
const std::vector<float>& expected_values_y,
|
||||
int provider_type, OrtCustomOpDomain* custom_op_domain_ptr) {
|
||||
|
|
@ -101,8 +108,8 @@ void TestInference(OrtEnv* env, T model_uri,
|
|||
//without preallocated output tensor
|
||||
RunSession(default_allocator.get(),
|
||||
inference_session.get(),
|
||||
dims_x,
|
||||
values_x,
|
||||
inputs,
|
||||
output_name,
|
||||
expected_dims_y,
|
||||
expected_values_y,
|
||||
nullptr);
|
||||
|
|
@ -123,8 +130,8 @@ void TestInference(OrtEnv* env, T model_uri,
|
|||
for (int i = 0; i != 2; ++i)
|
||||
RunSession(default_allocator.get(),
|
||||
inference_session.get(),
|
||||
dims_x,
|
||||
values_x,
|
||||
inputs,
|
||||
output_name,
|
||||
expected_dims_y,
|
||||
expected_values_y,
|
||||
value_y.get());
|
||||
|
|
@ -141,14 +148,17 @@ class CApiTestWithProvider : public CApiTest,
|
|||
TEST_P(CApiTestWithProvider, simple) {
|
||||
// simple inference test
|
||||
// prepare inputs
|
||||
std::vector<int64_t> dims_x = {3, 2};
|
||||
std::vector<float> values_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
std::vector<Input> inputs(1);
|
||||
Input& input = inputs.back();
|
||||
input.name = "X";
|
||||
input.dims = {3, 2};
|
||||
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_y = {3, 2};
|
||||
std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
|
||||
|
||||
TestInference<PATH_TYPE>(env, MODEL_URI, dims_x, values_x, expected_dims_y, expected_values_y, GetParam(), nullptr);
|
||||
TestInference<PATH_TYPE>(env, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
|
||||
|
|
@ -156,10 +166,9 @@ INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
|
|||
::testing::Values(0, 1, 2, 3, 4));
|
||||
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions(const OrtCustomOpApi& ort, OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
ORT_THROW_ON_ERROR(ort.GetTensorShapeAndType(value, &info));
|
||||
auto dimensionCount = ort.GetNumOfDimensions(info);
|
||||
OrtTensorDimensions(onnxruntime::CustomOpApi ort, OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorShapeAndType(value);
|
||||
auto dimensionCount = ort.GetDimensionCount(info);
|
||||
resize(dimensionCount);
|
||||
ort.GetDimensions(info, data(), dimensionCount);
|
||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
|
|
@ -178,57 +187,60 @@ template <typename T, size_t N>
|
|||
constexpr size_t countof(T (&)[N]) { return N; }
|
||||
|
||||
struct MyCustomKernel {
|
||||
MyCustomKernel(const OrtCustomOpApi& ort, const OrtKernelInfo& /*info*/) : ort_(ort) {
|
||||
MyCustomKernel(onnxruntime::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
|
||||
}
|
||||
|
||||
void GetOutputShape(OrtValue** inputs, size_t /*input_count*/, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) {
|
||||
OrtTensorDimensions dimensions(ort_, inputs[0]);
|
||||
ORT_THROW_ON_ERROR(ort_.SetDims(info, dimensions.data(), dimensions.size()));
|
||||
void GetOutputShape(OrtKernelContext* context, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) {
|
||||
OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
ort_.SetDimensions(info, dimensions.data(), dimensions.size());
|
||||
}
|
||||
|
||||
void Compute(OrtValue** inputs, size_t /*input_count*/, OrtValue** outputs, size_t /*output_count*/) {
|
||||
const float* X;
|
||||
const float* Y;
|
||||
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[0], reinterpret_cast<void**>(const_cast<float**>(&X))));
|
||||
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[1], reinterpret_cast<void**>(const_cast<float**>(&Y))));
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||
float* X = ort_.GetTensorMutableData<float>(input_X);
|
||||
float* Y = ort_.GetTensorMutableData<float>(input_Y);
|
||||
|
||||
float* out;
|
||||
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(outputs[0], reinterpret_cast<void**>(&out)));
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out = ort_.GetTensorMutableData<float>(output);
|
||||
|
||||
int64_t size = OrtTensorDimensions(ort_, inputs[0]).ElementCount();
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorShapeAndType(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const OrtCustomOpApi& ort_;
|
||||
onnxruntime::CustomOpApi ort_;
|
||||
};
|
||||
|
||||
struct MyCustomOp : OrtCustomOp {
|
||||
MyCustomOp() {
|
||||
OrtCustomOp::version = ORT_API_VERSION;
|
||||
OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, const OrtCustomOpApi* api, const OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*api, *info); };
|
||||
OrtCustomOp::GetName = [](OrtCustomOp* /*this_*/) { return "Foo"; };
|
||||
struct MyCustomOp : onnxruntime::CustomOpBase<MyCustomOp, MyCustomKernel> {
|
||||
void* CreateKernel(onnxruntime::CustomOpApi api, const OrtKernelInfo* info) { return new MyCustomKernel(api, info); };
|
||||
const char* GetName() const { return "Foo"; };
|
||||
|
||||
static const ONNXTensorElementDataType c_inputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
|
||||
OrtCustomOp::GetInputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_inputTypes); };
|
||||
OrtCustomOp::GetInputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_inputTypes[index]; };
|
||||
size_t GetInputTypeCount() const { return 2; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
static const ONNXTensorElementDataType c_outputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
|
||||
OrtCustomOp::GetOutputTypeCount = [](OrtCustomOp* /*this_*/) { return countof(c_outputTypes); };
|
||||
OrtCustomOp::GetOutputType = [](OrtCustomOp* /*this_*/, size_t index) { return c_outputTypes[index]; };
|
||||
|
||||
OrtCustomOp::KernelGetOutputShape = [](void* op_kernel, OrtValue** inputs, size_t input_count, size_t output_index, OrtTensorTypeAndShapeInfo* output) { static_cast<MyCustomKernel*>(op_kernel)->GetOutputShape(inputs, input_count, output_index, output); };
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtValue** inputs, size_t input_count, OrtValue** outputs, size_t output_count) { static_cast<MyCustomKernel*>(op_kernel)->Compute(inputs, input_count, outputs, output_count); };
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<MyCustomKernel*>(op_kernel); };
|
||||
}
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
};
|
||||
|
||||
TEST_F(CApiTest, custom_op_handler) {
|
||||
std::cout << "Running custom op inference" << std::endl;
|
||||
std::vector<int64_t> dims_x = {3, 2};
|
||||
std::vector<float> values_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
|
||||
std::vector<Input> inputs(1);
|
||||
Input& input = inputs[0];
|
||||
input.name = "X";
|
||||
input.dims = {3, 2};
|
||||
input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_y = {3, 2};
|
||||
|
|
@ -238,7 +250,7 @@ TEST_F(CApiTest, custom_op_handler) {
|
|||
OrtCustomOpDomain* custom_op_domain = OrtCreateCustomOpDomain("");
|
||||
ORT_THROW_ON_ERROR(OrtCustomOpDomain_Add(custom_op_domain, &custom_op));
|
||||
|
||||
TestInference<PATH_TYPE>(env, CUSTOM_OP_MODEL_URI, dims_x, values_x, expected_dims_y, expected_values_y, 0, custom_op_domain);
|
||||
TestInference<PATH_TYPE>(env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0, custom_op_domain);
|
||||
OrtReleaseCustomOpDomain(custom_op_domain);
|
||||
}
|
||||
|
||||
|
|
@ -251,6 +263,7 @@ TEST_F(CApiTest, create_session_without_session_option) {
|
|||
OrtReleaseSession(ret);
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(CApiTest, create_tensor) {
|
||||
const char* s[] = {"abc", "kmp"};
|
||||
int64_t expected_len = 2;
|
||||
|
|
|
|||
Loading…
Reference in a new issue