mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Custom op interface to the C API to remove shared library dependency (#668)
* Adding a custom op interface to the C API to remove shared library dependency. * Fixup const issues * Renaming to make things a little simpler * Add a comment
This commit is contained in:
parent
6c40aed95c
commit
cd52431b8f
3 changed files with 58 additions and 21 deletions
|
|
@ -538,8 +538,24 @@ typedef struct OrtKernelInfo OrtKernelInfo;
|
|||
/*
|
||||
* These allow reading node attributes during kernel creation
|
||||
*/
|
||||
ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
|
||||
ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
|
||||
ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
|
||||
ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
|
||||
|
||||
struct OrtCustomOpApi {
|
||||
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
|
||||
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
|
||||
|
||||
OrtStatus*(ORT_API_CALL* GetTensorShapeAndType)(_In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
|
||||
|
||||
size_t(ORT_API_CALL* GetNumOfDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info);
|
||||
void(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
|
||||
OrtStatus*(ORT_API_CALL* SetDims)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
|
||||
|
||||
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Out_ void** out);
|
||||
|
||||
void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(OrtTensorTypeAndShapeInfo* input);
|
||||
};
|
||||
typedef struct OrtCustomOpApi OrtCustomOpApi;
|
||||
|
||||
/*
|
||||
* The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by
|
||||
|
|
@ -549,7 +565,7 @@ struct OrtCustomOp {
|
|||
uint32_t version; // Initialize to ORT_API_VERSION
|
||||
|
||||
// This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
|
||||
void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ OrtKernelInfo* info, _Out_ void** op_kernel);
|
||||
void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info, _Out_ void** op_kernel);
|
||||
|
||||
// Returns the name of the op
|
||||
const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op);
|
||||
|
|
|
|||
|
|
@ -58,17 +58,32 @@
|
|||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
constexpr OrtCustomOpApi g_custom_op_api = {
|
||||
&OrtKernelInfoGetAttribute_float,
|
||||
&OrtKernelInfoGetAttribute_int64,
|
||||
|
||||
&OrtGetTensorShapeAndType,
|
||||
|
||||
&OrtGetNumOfDimensions,
|
||||
&OrtGetDimensions,
|
||||
&OrtSetDims,
|
||||
|
||||
&OrtGetTensorMutableData,
|
||||
|
||||
&OrtReleaseTensorTypeAndShapeInfo,
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
|
||||
auto status = reinterpret_cast<onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
|
||||
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
|
||||
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
|
||||
if (status.IsOK())
|
||||
return nullptr;
|
||||
return onnxruntime::ToOrtStatus(status);
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
|
||||
auto status = reinterpret_cast<onnxruntime::OpKernelInfo*>(info)->GetAttr<int64_t>(name, out);
|
||||
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
|
||||
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<int64_t>(name, out);
|
||||
if (status.IsOK())
|
||||
return nullptr;
|
||||
return onnxruntime::ToOrtStatus(status);
|
||||
|
|
@ -111,7 +126,9 @@ inline std::basic_string<T> GetCurrentTimeString() {
|
|||
} // namespace
|
||||
struct CustomOpKernel : OpKernel {
|
||||
CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) {
|
||||
op_.CreateKernel(&op_, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)), &op_kernel_);
|
||||
if (op_.version != 1)
|
||||
throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
|
||||
op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)), &op_kernel_);
|
||||
}
|
||||
|
||||
~CustomOpKernel() {
|
||||
|
|
|
|||
|
|
@ -156,13 +156,13 @@ INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
|
|||
::testing::Values(0, 1, 2, 3, 4));
|
||||
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions(OrtValue* value) {
|
||||
OrtTensorDimensions(const OrtCustomOpApi& ort, OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info;
|
||||
ORT_THROW_ON_ERROR(OrtGetTensorShapeAndType(value, &info));
|
||||
auto dimensionCount = OrtGetNumOfDimensions(info);
|
||||
ORT_THROW_ON_ERROR(ort.GetTensorShapeAndType(value, &info));
|
||||
auto dimensionCount = ort.GetNumOfDimensions(info);
|
||||
resize(dimensionCount);
|
||||
OrtGetDimensions(info, data(), dimensionCount);
|
||||
OrtReleaseTensorTypeAndShapeInfo(info);
|
||||
ort.GetDimensions(info, data(), dimensionCount);
|
||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
|
||||
size_t ElementCount() const {
|
||||
|
|
@ -173,38 +173,42 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
}
|
||||
};
|
||||
|
||||
// Once we use C++17 this could be replaced with std::size
|
||||
template <typename T, size_t N>
|
||||
constexpr size_t countof(T (&)[N]) { return N; }
|
||||
|
||||
struct MyCustomKernel {
|
||||
MyCustomKernel(OrtKernelInfo& /*info*/) {
|
||||
MyCustomKernel(const OrtCustomOpApi& ort, const OrtKernelInfo& /*info*/) : ort_(ort) {
|
||||
}
|
||||
|
||||
void GetOutputShape(OrtValue** inputs, size_t /*input_count*/, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) {
|
||||
OrtTensorDimensions dimensions(inputs[0]);
|
||||
ORT_THROW_ON_ERROR(OrtSetDims(info, dimensions.data(), dimensions.size()));
|
||||
OrtTensorDimensions dimensions(ort_, inputs[0]);
|
||||
ORT_THROW_ON_ERROR(ort_.SetDims(info, dimensions.data(), dimensions.size()));
|
||||
}
|
||||
|
||||
void Compute(OrtValue** inputs, size_t /*input_count*/, OrtValue** outputs, size_t /*output_count*/) {
|
||||
const float* X;
|
||||
const float* Y;
|
||||
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[0], reinterpret_cast<void**>(const_cast<float**>(&X))));
|
||||
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[1], reinterpret_cast<void**>(const_cast<float**>(&Y))));
|
||||
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[0], reinterpret_cast<void**>(const_cast<float**>(&X))));
|
||||
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[1], reinterpret_cast<void**>(const_cast<float**>(&Y))));
|
||||
|
||||
float* out;
|
||||
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(outputs[0], reinterpret_cast<void**>(&out)));
|
||||
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(outputs[0], reinterpret_cast<void**>(&out)));
|
||||
|
||||
int64_t size = OrtTensorDimensions(inputs[0]).ElementCount();
|
||||
int64_t size = OrtTensorDimensions(ort_, inputs[0]).ElementCount();
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const OrtCustomOpApi& ort_;
|
||||
};
|
||||
|
||||
struct MyCustomOp : OrtCustomOp {
|
||||
MyCustomOp() {
|
||||
OrtCustomOp::version = ORT_API_VERSION;
|
||||
OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*info); };
|
||||
OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, const OrtCustomOpApi* api, const OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*api, *info); };
|
||||
OrtCustomOp::GetName = [](OrtCustomOp* /*this_*/) { return "Foo"; };
|
||||
|
||||
static const ONNXTensorElementDataType c_inputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
|
||||
|
|
|
|||
Loading…
Reference in a new issue