Custom op interface to the C API to remove shared library dependency (#668)

* Adding a custom op interface to the C API to remove shared library dependency.

* Fixup const issues

* Renaming to make things a little simpler

* Add a comment
This commit is contained in:
Ryan Hill 2019-03-21 15:46:50 -07:00 committed by GitHub
parent 6c40aed95c
commit cd52431b8f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 21 deletions

View file

@ -538,8 +538,24 @@ typedef struct OrtKernelInfo OrtKernelInfo;
/*
* These allow reading node attributes during kernel creation
*/
ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
struct OrtCustomOpApi {
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out);
OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out);
OrtStatus*(ORT_API_CALL* GetTensorShapeAndType)(_In_ const OrtValue* value, _Out_ OrtTensorTypeAndShapeInfo** out);
size_t(ORT_API_CALL* GetNumOfDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info);
void(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
OrtStatus*(ORT_API_CALL* SetDims)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Out_ void** out);
void(ORT_API_CALL* ReleaseTensorTypeAndShapeInfo)(OrtTensorTypeAndShapeInfo* input);
};
typedef struct OrtCustomOpApi OrtCustomOpApi;
/*
* The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by
@ -549,7 +565,7 @@ struct OrtCustomOp {
uint32_t version; // Initialize to ORT_API_VERSION
// This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ OrtKernelInfo* info, _Out_ void** op_kernel);
void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtCustomOpApi* api, _In_ const OrtKernelInfo* info, _Out_ void** op_kernel);
// Returns the name of the op
const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op);

View file

@ -58,17 +58,32 @@
using namespace ONNX_NAMESPACE;
constexpr OrtCustomOpApi g_custom_op_api = {
&OrtKernelInfoGetAttribute_float,
&OrtKernelInfoGetAttribute_int64,
&OrtGetTensorShapeAndType,
&OrtGetNumOfDimensions,
&OrtGetDimensions,
&OrtSetDims,
&OrtGetTensorMutableData,
&OrtReleaseTensorTypeAndShapeInfo,
};
ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType(const onnxruntime::DataTypeImpl* cpp_type);
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
auto status = reinterpret_cast<onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
if (status.IsOK())
return nullptr;
return onnxruntime::ToOrtStatus(status);
}
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
auto status = reinterpret_cast<onnxruntime::OpKernelInfo*>(info)->GetAttr<int64_t>(name, out);
ORT_API_STATUS_IMPL(OrtKernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<int64_t>(name, out);
if (status.IsOK())
return nullptr;
return onnxruntime::ToOrtStatus(status);
@ -111,7 +126,9 @@ inline std::basic_string<T> GetCurrentTimeString() {
} // namespace
struct CustomOpKernel : OpKernel {
CustomOpKernel(const OpKernelInfo& info, OrtCustomOp& op) : OpKernel(info), op_(op) {
op_.CreateKernel(&op_, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)), &op_kernel_);
if (op_.version != 1)
throw std::invalid_argument("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op));
op_.CreateKernel(&op_, &g_custom_op_api, reinterpret_cast<OrtKernelInfo*>(const_cast<OpKernelInfo*>(&info)), &op_kernel_);
}
~CustomOpKernel() {

View file

@ -156,13 +156,13 @@ INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
::testing::Values(0, 1, 2, 3, 4));
struct OrtTensorDimensions : std::vector<int64_t> {
OrtTensorDimensions(OrtValue* value) {
OrtTensorDimensions(const OrtCustomOpApi& ort, OrtValue* value) {
OrtTensorTypeAndShapeInfo* info;
ORT_THROW_ON_ERROR(OrtGetTensorShapeAndType(value, &info));
auto dimensionCount = OrtGetNumOfDimensions(info);
ORT_THROW_ON_ERROR(ort.GetTensorShapeAndType(value, &info));
auto dimensionCount = ort.GetNumOfDimensions(info);
resize(dimensionCount);
OrtGetDimensions(info, data(), dimensionCount);
OrtReleaseTensorTypeAndShapeInfo(info);
ort.GetDimensions(info, data(), dimensionCount);
ort.ReleaseTensorTypeAndShapeInfo(info);
}
size_t ElementCount() const {
@ -173,38 +173,42 @@ struct OrtTensorDimensions : std::vector<int64_t> {
}
};
// Once we use C++17 this could be replaced with std::size
template <typename T, size_t N>
constexpr size_t countof(T (&)[N]) { return N; }
struct MyCustomKernel {
MyCustomKernel(OrtKernelInfo& /*info*/) {
MyCustomKernel(const OrtCustomOpApi& ort, const OrtKernelInfo& /*info*/) : ort_(ort) {
}
void GetOutputShape(OrtValue** inputs, size_t /*input_count*/, size_t /*output_index*/, OrtTensorTypeAndShapeInfo* info) {
OrtTensorDimensions dimensions(inputs[0]);
ORT_THROW_ON_ERROR(OrtSetDims(info, dimensions.data(), dimensions.size()));
OrtTensorDimensions dimensions(ort_, inputs[0]);
ORT_THROW_ON_ERROR(ort_.SetDims(info, dimensions.data(), dimensions.size()));
}
void Compute(OrtValue** inputs, size_t /*input_count*/, OrtValue** outputs, size_t /*output_count*/) {
const float* X;
const float* Y;
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[0], reinterpret_cast<void**>(const_cast<float**>(&X))));
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(inputs[1], reinterpret_cast<void**>(const_cast<float**>(&Y))));
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[0], reinterpret_cast<void**>(const_cast<float**>(&X))));
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(inputs[1], reinterpret_cast<void**>(const_cast<float**>(&Y))));
float* out;
ORT_THROW_ON_ERROR(OrtGetTensorMutableData(outputs[0], reinterpret_cast<void**>(&out)));
ORT_THROW_ON_ERROR(ort_.GetTensorMutableData(outputs[0], reinterpret_cast<void**>(&out)));
int64_t size = OrtTensorDimensions(inputs[0]).ElementCount();
int64_t size = OrtTensorDimensions(ort_, inputs[0]).ElementCount();
for (int64_t i = 0; i < size; i++) {
out[i] = X[i] + Y[i];
}
}
private:
const OrtCustomOpApi& ort_;
};
struct MyCustomOp : OrtCustomOp {
MyCustomOp() {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*info); };
OrtCustomOp::CreateKernel = [](OrtCustomOp* /*this_*/, const OrtCustomOpApi* api, const OrtKernelInfo* info, void** output) { *output = new MyCustomKernel(*api, *info); };
OrtCustomOp::GetName = [](OrtCustomOp* /*this_*/) { return "Foo"; };
static const ONNXTensorElementDataType c_inputTypes[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};