mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
136 lines
4.1 KiB
C++
136 lines
4.1 KiB
C++
#include "custom_op_library.h"
|
|
|
|
#define ORT_API_MANUAL_INIT
|
|
#include "onnxruntime_cxx_api.h"
|
|
#undef ORT_API_MANUAL_INIT
|
|
|
|
#include <vector>
|
|
#include <cmath>
|
|
|
|
// Name of the custom-op domain both ops below are registered under.
// The pointer itself is const so the domain name cannot be reassigned at runtime.
static const char* const c_OpDomain = "test.customop";
|
|
|
|
struct OrtTensorDimensions : std::vector<int64_t> {
|
|
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
|
|
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
|
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
|
ort.ReleaseTensorTypeAndShapeInfo(info);
|
|
}
|
|
};
|
|
|
|
|
|
struct KernelOne {
|
|
KernelOne(OrtApi api)
|
|
:api_(api),
|
|
ort_(api_)
|
|
{
|
|
}
|
|
|
|
void Compute(OrtKernelContext* context) {
|
|
// Setup inputs
|
|
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
|
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
|
const float* X = ort_.GetTensorData<float>(input_X);
|
|
const float* Y = ort_.GetTensorData<float>(input_Y);
|
|
|
|
// Setup output
|
|
OrtTensorDimensions dimensions(ort_, input_X);
|
|
|
|
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
|
float* out = ort_.GetTensorMutableData<float>(output);
|
|
|
|
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
|
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
|
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
|
|
|
// Do computation
|
|
for (int64_t i = 0; i < size; i++) {
|
|
out[i] = X[i] + Y[i];
|
|
}
|
|
}
|
|
|
|
private:
|
|
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
|
Ort::CustomOpApi ort_;
|
|
};
|
|
|
|
struct KernelTwo {
|
|
KernelTwo(OrtApi api)
|
|
: api_(api),
|
|
ort_(api_) {
|
|
}
|
|
|
|
void Compute(OrtKernelContext* context) {
|
|
// Setup inputs
|
|
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
|
const float* X = ort_.GetTensorData<float>(input_X);
|
|
|
|
// Setup output
|
|
OrtTensorDimensions dimensions(ort_, input_X);
|
|
|
|
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
|
int32_t* out = ort_.GetTensorMutableData<int32_t>(output);
|
|
|
|
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
|
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
|
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
|
|
|
// Do computation
|
|
for (int64_t i = 0; i < size; i++) {
|
|
out[i] = (int32_t)(round(X[i]));
|
|
}
|
|
}
|
|
|
|
private:
|
|
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
|
Ort::CustomOpApi ort_;
|
|
};
|
|
|
|
// Operator schema "CustomOpOne": two float tensor inputs, one float tensor
// output, executed by KernelOne. A single global instance is defined here.
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
  // Heap-allocates a KernelOne and hands it back as an opaque pointer.
  // NOTE(review): presumably freed by CustomOpBase's kernel-destroy hook.
  void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
    KernelOne* kernel = new KernelOne(api);
    return kernel;
  }

  // Operator name as referenced by model nodes.
  const char* GetName() const {
    return "CustomOpOne";
  }

  // Inputs: two tensors, every one of them float.
  size_t GetInputTypeCount() const {
    return 2;
  }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Outputs: a single float tensor.
  size_t GetOutputTypeCount() const {
    return 1;
  }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
} c_CustomOpOne;
|
|
|
|
// Operator schema "CustomOpTwo": one float tensor input, one int32 tensor
// output, executed by KernelTwo. A single global instance is defined here.
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
  // Heap-allocates a KernelTwo and hands it back as an opaque pointer.
  // NOTE(review): presumably freed by CustomOpBase's kernel-destroy hook.
  void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
    KernelTwo* kernel = new KernelTwo(api);
    return kernel;
  }

  // Operator name as referenced by model nodes.
  const char* GetName() const {
    return "CustomOpTwo";
  }

  // Inputs: a single float tensor.
  size_t GetInputTypeCount() const {
    return 1;
  }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Outputs: a single int32 tensor.
  size_t GetOutputTypeCount() const {
    return 1;
  }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  }
} c_CustomOpTwo;
|
|
|
|
|
|
// Library entry point: creates the custom-op domain named by c_OpDomain,
// adds CustomOpOne and CustomOpTwo to it, and attaches the domain to the
// given session options.
//
// options - session options that receive the domain.
// api     - API base used to obtain the OrtApi for ORT_API_VERSION.
// Returns the first non-null OrtStatus from a failing call; otherwise the
// status returned by AddCustomOpDomain.
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  // NOTE(review): GetApi may yield nullptr if the runtime does not support
  // ORT_API_VERSION; there is no OrtApi to build an error status from then.
  const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);

  OrtCustomOpDomain* domain = nullptr;
  if (OrtStatus* status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
    return status;
  }

  OrtCustomOp* const ops[] = {&c_CustomOpOne, &c_CustomOpTwo};
  for (OrtCustomOp* op : ops) {
    if (OrtStatus* status = ortApi->CustomOpDomain_Add(domain, op)) {
      // The domain has not been handed to the session options yet, so this
      // function still owns it; release it instead of leaking it.
      ortApi->ReleaseCustomOpDomain(domain);
      return status;
    }
  }

  // On success the domain is deliberately not released here: it is handed to
  // the session options and must stay alive for the sessions that use it.
  return ortApi->AddCustomOpDomain(options, domain);
}
|