onnxruntime/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
Yulong Wang fdc5c308c4
introduce macro ORT_API_MANUAL_INIT in C++ API (#4536)
* introduce macro ORT_API_MANUAL_INIT in C++ API

* resolve comments
2020-07-17 13:23:30 -07:00

136 lines
4.1 KiB
C++

#include "custom_op_library.h"
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT
#include <cmath>
#include <limits>
#include <stdexcept>
#include <vector>
// Name of the ONNX operator domain the custom ops in this library are registered under
// (passed to CreateCustomOpDomain in RegisterCustomOps). Both the pointer and the characters
// are const so the domain name cannot be reassigned at runtime.
static const char* const c_OpDomain = "test.customop";
// The dimension list of a tensor OrtValue, captured once at construction time.
struct OrtTensorDimensions : std::vector<int64_t> {
  // ort:   API wrapper used to query the value's type/shape info.
  // value: tensor whose shape is copied into this vector.
  OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
    // Releases the type/shape info on every exit path. This includes the case where
    // GetTensorShape throws (the Ort C++ wrapper presumably reports errors by throwing), so
    // the info object can no longer leak.
    struct ShapeInfoReleaser {
      Ort::CustomOpApi& api;
      OrtTensorTypeAndShapeInfo* info;
      ~ShapeInfoReleaser() { api.ReleaseTensorTypeAndShapeInfo(info); }
    } releaser{ort, ort.GetTensorTypeAndShape(value)};

    std::vector<int64_t>::operator=(ort.GetTensorShape(releaser.info));
  }
};
struct KernelOne {
KernelOne(OrtApi api)
:api_(api),
ort_(api_)
{
}
void Compute(OrtKernelContext* context) {
// Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
const float* X = ort_.GetTensorData<float>(input_X);
const float* Y = ort_.GetTensorData<float>(input_Y);
// Setup output
OrtTensorDimensions dimensions(ort_, input_X);
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
float* out = ort_.GetTensorMutableData<float>(output);
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
int64_t size = ort_.GetTensorShapeElementCount(output_info);
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
// Do computation
for (int64_t i = 0; i < size; i++) {
out[i] = X[i] + Y[i];
}
}
private:
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
Ort::CustomOpApi ort_;
};
// Kernel for CustomOpTwo: rounds each float of X to the nearest integer (halfway cases away
// from zero) and stores the result as int32. The output is allocated with X's shape.
struct KernelTwo {
  explicit KernelTwo(OrtApi api)
      : api_(api),
        ort_(api_) {
  }

  void Compute(OrtKernelContext* context) {
    // Setup inputs
    const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
    const float* X = ort_.GetTensorData<float>(input_X);

    // Setup output with X's shape; its element count is the product of those dimensions.
    OrtTensorDimensions dimensions(ort_, input_X);
    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
    int32_t* out = ort_.GetTensorMutableData<int32_t>(output);
    int64_t size = 1;
    for (const int64_t dim : dimensions) {
      size *= dim;
    }

    // Do computation
    for (int64_t i = 0; i < size; i++) {
      out[i] = RoundToInt32(X[i]);
    }
  }

 private:
  // Rounds `value` half away from zero and converts it to int32 without undefined behavior.
  // Converting an out-of-range or NaN float directly to an integer is UB. Instead, values
  // beyond the int32 range saturate to its limits and NaN maps to 0.
  static int32_t RoundToInt32(float value) {
    const float rounded = std::round(value);
    if (std::isnan(rounded)) {
      return 0;
    }
    // 2^31 is exactly representable as a float. Every float in [-2^31, 2^31) fits in int32.
    constexpr float kTwoPow31 = 2147483648.0f;
    if (rounded >= kTwoPow31) {
      return std::numeric_limits<int32_t>::max();
    }
    if (rounded < -kTwoPow31) {
      return std::numeric_limits<int32_t>::min();
    }
    return static_cast<int32_t>(rounded);
  }

  OrtApi api_;  // keep a copy of the struct, whose ref is used in the ort_
  Ort::CustomOpApi ort_;
};
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
return new KernelOne(api);
};
const char* GetName() const { return "CustomOpOne"; };
size_t GetInputTypeCount() const { return 2; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
} c_CustomOpOne;
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) {
return new KernelTwo(api);
};
const char* GetName() const { return "CustomOpTwo"; };
size_t GetInputTypeCount() const { return 1; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
} c_CustomOpTwo;
// Library entry point. It creates the c_OpDomain custom-op domain, adds CustomOpOne and
// CustomOpTwo to it, and attaches the domain to `options`.
// Returns nullptr on success, or the status of the first failing call.
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  // NOTE(review): GetApi returns nullptr if the loaded runtime does not support ORT_API_VERSION.
  // That case is not handled here, because an OrtStatus cannot be created without an OrtApi.
  const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);

  OrtCustomOpDomain* domain = nullptr;
  if (OrtStatus* status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
    return status;
  }

  // Run the remaining steps in order, stopping at the first failure.
  OrtStatus* status = ortApi->CustomOpDomain_Add(domain, &c_CustomOpOne);
  if (status == nullptr) {
    status = ortApi->CustomOpDomain_Add(domain, &c_CustomOpTwo);
  }
  if (status == nullptr) {
    status = ortApi->AddCustomOpDomain(options, domain);
  }

  if (status != nullptr) {
    // Registration failed, so the domain is released instead of leaked. This assumes a failed
    // AddCustomOpDomain does not retain the pointer; confirm against the C API docs.
    ortApi->ReleaseCustomOpDomain(domain);
  }
  // On success the domain is deliberately not released. `options` now refers to it, so it
  // must stay alive for the sessions created from those options.
  return status;
}