onnxruntime/onnxruntime/core/session/custom_ops.cc
cao lei 966fa74597
Add 2 C APIs for ORT Extensions (#19808)
### Description
<!-- Describe your changes. -->
Add 2 C APIs for ORT Extensions:
- KernelInfoGetAllocator
- OrtCustomOp::GetMayInplace


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Add 2 C APIs for the ORT Extensions project, which will use them to implement
its GroupQueryAttention custom op.
2024-03-14 06:00:41 -07:00

1250 lines
53 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef _WIN32
#pragma warning(disable : 4267)
#endif
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include "core/common/gsl.h"
#include "core/framework/data_types.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/op_kernel_info.h"
#include "core/framework/tensor_type_and_shape.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/allocator_adapters.h"
#include "core/session/api_utils.h"
#include "core/session/custom_ops.h"
#include "core/session/inference_session.h"
#include "core/session/ort_apis.h"
#include "core/platform/threadpool.h"
// NOTE: OrtKernelContext is used by both custom ops and compiled kernels.
// In a minimal build, ORT_EXTENDED_MINIMAL_BUILD is used to enable EPs like CoreML/NNAPI which use compiled kernels,
// and ORT_MINIMAL_BUILD_CUSTOM_OPS is used to allow external custom op libraries to be used.
// ENABLE_ORT_KERNEL_CONTEXT_API selects the real implementation of ExecuteIfKernelContextApiEnabled (used by the
// KernelContext_* and Logger_* entry points in this file). When it is left undefined it evaluates to 0 in #if,
// and those entry points return ORT_NOT_IMPLEMENTED.
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
#define ENABLE_ORT_KERNEL_CONTEXT_API 1
#endif
// ENABLE_CUSTOM_OP_API plays the same role for the custom-op entry points (ShapeInferContext_*, ReadOpAttr,
// KernelInfo*). It also gates the CustomOpKernel implementation later in this file.
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
#define ENABLE_CUSTOM_OP_API 1
#endif
// Minimum OrtCustomOp::version at which the corresponding OrtCustomOp features are consulted. Ops that declare an
// older version are handled with the behavior that predates each feature.
#if !defined(ORT_MINIMAL_BUILD)
static constexpr uint32_t min_ort_version_with_optional_io_support = 8;
static constexpr uint32_t min_ort_version_with_variadic_io_support = 14;
static constexpr uint32_t min_ort_version_with_custom_version = 17;
#endif
#if ENABLE_CUSTOM_OP_API
static constexpr uint32_t min_ort_version_with_compute_v2_support = 16;
static constexpr uint32_t min_ort_version_with_shape_inference = 17;
#endif
// Tensor types allowed for a custom op input/output declared as ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED.
// The IRv9 type set is used unless DISABLE_FLOAT8_TYPES is defined, in which case the IRv4 set is used.
#if !defined(DISABLE_FLOAT8_TYPES)
#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv9()
#else
#define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv4()
#endif
#if defined(ORT_MINIMAL_BUILD)
// Minimal-build stand-in for the shape inference context. ONNX shape inference is not available in this build,
// so the stub exposes no inputs and no attributes, and refuses to set output shapes.
struct OrtShapeInferContext {
  size_t GetInputCount() const { return 0; }
  OrtTensorTypeAndShapeInfo* GetInputTypeShape(size_t) const { return {}; }
  onnxruntime::Status SetOutputTypeShape(size_t, const OrtTensorTypeAndShapeInfo*) const {
    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                           "OrtShapeInferContext::SetOutputTypeShape not implemented for minimal build");
  }
  const ONNX_NAMESPACE::AttributeProto* GetAttr(const char*) const { return {}; }
};
#else
// Wraps an ONNX InferenceContext so that custom-op shape inference functions can use it through the C API.
// - Input types and shapes are captured once, at construction.
// - Output shapes are written straight back into the wrapped context.
// The context is held by reference and must outlive this object.
struct OrtShapeInferContext {
  // Captures, for every input, its element type, integer dims and symbolic dims.
  // Throws (ORT_ENFORCE) if an input that is present is not a tensor.
  OrtShapeInferContext(ONNX_NAMESPACE::InferenceContext& ctx) : ctx_(ctx) {
    const auto num_inputs = ctx_.getNumInputs();
    for (size_t ith_input = 0; ith_input < num_inputs; ++ith_input) {
      const auto* input_type = ctx_.getInputType(ith_input);
      if (input_type == nullptr) {
        // getInputType() yields nullptr for an input without a type, e.g. an omitted optional input.
        // Keep an empty slot so indices stay aligned with the node's inputs. GetInputTypeShape() then returns
        // nullptr for this slot instead of the constructor dereferencing a null pointer.
        input_type_shapes_.emplace_back();
        continue;
      }
      const auto& value_case = input_type->value_case();
      ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType,
                  "shape inference not yet supported for non-tensor types");
      const auto& type_proto = input_type->tensor_type();
      const auto& shape_proto = type_proto.shape();
      auto elem_type = ::onnxruntime::utils::CApiElementTypeFromProtoType(type_proto.elem_type());
      auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto);
      auto symbolic_dims = GetSymbolicDims(shape_proto);
      input_type_shapes_.emplace_back(
          OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release());
    }
  }

  ~OrtShapeInferContext() = default;

  size_t GetInputCount() const { return input_type_shapes_.size(); }

  // Returns the captured info for input `idx`, or nullptr for an input that had no type.
  // Throws std::out_of_range when `idx` >= GetInputCount().
  OrtTensorTypeAndShapeInfo* GetInputTypeShape(size_t idx) const {
    return input_type_shapes_.at(idx).get();
  }

  // Writes the dims of `info` as the shape of output `index`:
  // - a dim with a non-empty symbolic name is written as dim_param;
  // - otherwise a non-negative integer is written as dim_value;
  // - a negative integer (unknown size) is left unset, which is how ONNX represents an unknown dim.
  //   Writing a negative dim_value would produce an invalid shape.
  // Only the shape is propagated; the element type held in `info` is not written here.
  // Returns an error if `info` is null or its symbolic and integer dim counts differ.
  onnxruntime::Status SetOutputTypeShape(size_t index, const OrtTensorTypeAndShapeInfo* info) const {
    ORT_RETURN_IF_NOT(info, "Invalid shape info");
    ONNX_NAMESPACE::TensorShapeProto shape_proto;
    const auto& symbolic_dims = info->dim_params;
    const auto& integer_dims = info->shape.GetDims();
    ORT_RETURN_IF_NOT(symbolic_dims.size() == integer_dims.size(), "symbolic and integer dims mismatch!");
    for (size_t ith = 0; ith < symbolic_dims.size(); ith++) {
      auto* dim_proto = shape_proto.add_dim();
      if (symbolic_dims[ith].size() > 0) {
        dim_proto->set_dim_param(symbolic_dims[ith]);
      } else if (integer_dims[ith] >= 0) {
        dim_proto->set_dim_value(integer_dims[ith]);
      }
    }
    ONNX_NAMESPACE::updateOutputShape(ctx_, index, shape_proto);
    return onnxruntime::Status::OK();
  }

  // Returns the node attribute named `attr_name`, or nullptr when it is absent.
  const ONNX_NAMESPACE::AttributeProto* GetAttr(const char* attr_name) const {
    return ctx_.getAttribute(attr_name);
  }

 private:
  // One entry per dim: the dim_param for dims without a concrete value, or "" for dims that have one.
  static std::vector<std::string> GetSymbolicDims(const ONNX_NAMESPACE::TensorShapeProto& shape_proto) {
    std::vector<std::string> symbolic_dims;
    for (int ith = 0; ith < shape_proto.dim_size(); ith++) {
      const auto& dim = shape_proto.dim(ith);
      if (::onnxruntime::utils::HasDimValue(dim)) {
        symbolic_dims.emplace_back();
      } else {
        symbolic_dims.emplace_back(dim.dim_param());
      }
    }
    return symbolic_dims;
  }

  ONNX_NAMESPACE::InferenceContext& ctx_;
  using TypeShapePtr = std::unique_ptr<OrtTensorTypeAndShapeInfo>;
  // Indexed by input position; a null entry marks an input without a type.
  onnxruntime::InlinedVector<TypeShapePtr> input_type_shapes_;
};
#endif
#if ENABLE_ORT_KERNEL_CONTEXT_API
// Invokes `api_body` (a callable returning OrtStatusPtr) inside the API_IMPL_BEGIN/API_IMPL_END guard, which keeps
// exceptions from crossing the C API boundary.
template <typename Body>
static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const Body& api_body) {
  API_IMPL_BEGIN
  return api_body();
  API_IMPL_END
}
#else
// The OrtKernelContext API is compiled out. Callers get ORT_NOT_IMPLEMENTED and their body never runs.
template <typename Body>
static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const Body&) {
  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "OrtKernelContext API is not enabled in this build");
}
#endif
// Reports through `out` how many inputs the kernel context holds.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_ctx = *reinterpret_cast<const onnxruntime::OpKernelContextInternal*>(context);
    *out = kernel_ctx.InputCount();
    return nullptr;
  });
}
// Reports through `out` how many outputs the kernel context holds.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_ctx = *reinterpret_cast<const onnxruntime::OpKernelContextInternal*>(context);
    *out = kernel_ctx.OutputCount();
    return nullptr;
  });
}
// Returns the input value at position `index`. The index is narrowed to int, and the narrowing throws
// (turned into a status by the guard) when it does not fit.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index,
                    _Out_ const OrtValue** out) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const int input_idx = onnxruntime::narrow<int>(index);
    const auto* kernel_ctx = reinterpret_cast<const onnxruntime::OpKernelContextInternal*>(context);
    const auto* input_value = kernel_ctx->GetInputMLValue(input_idx);
    *out = reinterpret_cast<const OrtValue*>(input_value);
    return nullptr;
  });
}
// Returns the output value at position `index`, requesting it with the shape given by
// `dim_values[0..dim_count)`.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
                    _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const onnxruntime::TensorShape requested_shape(dim_values, dim_count);
    auto* kernel_ctx = reinterpret_cast<onnxruntime::OpKernelContextInternal*>(context);
    auto* output_value = kernel_ctx->OutputMLValue(onnxruntime::narrow<int>(index), requested_shape);
    *out = reinterpret_cast<OrtValue*>(output_value);
    return nullptr;
  });
}
// Returns the native handle of the kernel's compute stream, or nullptr when the kernel has no stream.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context,
                    _Outptr_ void** out) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const auto* kernel_ctx = reinterpret_cast<const onnxruntime::OpKernelContext*>(context);
    auto* compute_stream = kernel_ctx->GetComputeStream();
    *out = compute_stream != nullptr ? compute_stream->GetHandle() : nullptr;
    return nullptr;
  });
}
// Wraps the kernel context's allocator for `mem_info->device` in a newly allocated OrtAllocator.
// Ownership of the wrapper is handed to the caller. Fails with ORT_INVALID_ARGUMENT if no such allocator exists.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context,
                    _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_ctx = *reinterpret_cast<const onnxruntime::OpKernelContextInternal*>(context);
    onnxruntime::AllocatorPtr device_allocator = kernel_ctx.GetAllocator(mem_info->device);
    if (device_allocator == nullptr) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
    }
    *out = std::make_unique<onnxruntime::OrtAllocatorImplWrappingIAllocator>(std::move(device_allocator)).release();
    return nullptr;
  });
}
// Looks up resource (`resource_version`, `resource_id`) on the kernel's compute stream.
// *resource is cleared first. The call fails with ORT_INVALID_ARGUMENT when the kernel has no compute stream.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelContext* context,
                    _In_ int resource_version, _In_ int resource_id, _Outptr_ void** resource) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    *resource = {};
    const auto* kernel_ctx = reinterpret_cast<const onnxruntime::OpKernelContext*>(context);
    onnxruntime::Stream* host_stream = kernel_ctx->GetComputeStream();
    if (host_stream == nullptr) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource");
    }
    *resource = host_stream->GetResource(resource_version, resource_id);
    return nullptr;
  });
}
// Runs fn(usr_data, i) for every i in [0, total) on the kernel's operator thread pool.
// - A non-zero `num_batch` splits the work into that many batches.
// - A null `fn` or a zero `total` is a successful no-op.
// - A null context is rejected with ORT_RUNTIME_EXCEPTION.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context,
                    _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    if (context == nullptr) {
      return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Invalid context");
    }
    if (fn == nullptr || total == 0) {
      return nullptr;  // nothing to run
    }
    auto* thread_pool = reinterpret_cast<const onnxruntime::OpKernelContext*>(context)->GetOperatorThreadPool();
    const auto work_items = static_cast<std::ptrdiff_t>(total);
    auto run_one = [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast<size_t>(ith)); };
    if (num_batch != 0) {
      onnxruntime::concurrency::ThreadPool::TryBatchParallelFor(thread_pool, work_items, run_one,
                                                                static_cast<std::ptrdiff_t>(num_batch));
    } else {
      onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor(thread_pool, work_items, run_one);
    }
    return nullptr;
  });
}
// Exposes the kernel context's logger as an opaque OrtLogger. The logger is borrowed, not copied.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetLogger, _In_ const OrtKernelContext* context,
                    _Outptr_ const OrtLogger** logger) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const auto* kernel_ctx = reinterpret_cast<const onnxruntime::OpKernelContextInternal*>(context);
    *logger = reinterpret_cast<const OrtLogger*>(&kernel_ctx->Logger());
    return nullptr;
  });
}
// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger
// Emits `message` through `logger` at `log_severity_level`, tagged with the caller's file, line and function.
// Nothing is formatted when the logger would drop a SYSTEM message at that severity.
ORT_API_STATUS_IMPL(OrtApis::Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level,
                    _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number,
                    _In_z_ const char* func_name) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const auto* actual_logger = reinterpret_cast<const onnxruntime::logging::Logger*>(logger);
    const auto severity = static_cast<onnxruntime::logging::Severity>(log_severity_level);
    constexpr auto data_type = onnxruntime::logging::DataType::SYSTEM;
    if (!actual_logger->OutputIsEnabled(severity, data_type)) {
      return nullptr;
    }
#ifdef _WIN32
    // ORTCHAR_T is a wide character on Windows; CodeLocation expects a narrow UTF-8 path.
    const std::string file_path_utf8 = onnxruntime::ToUTF8String(file_path);
    const char* location_file = file_path_utf8.c_str();
#else
    const char* location_file = file_path;
#endif
    const onnxruntime::CodeLocation location(location_file, line_number, func_name);
    onnxruntime::logging::Capture(*actual_logger, severity, onnxruntime::logging::Category::onnxruntime,
                                  data_type, location)
            .Stream()
        << message;
    return nullptr;
  });
}
// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger
// Reports the logger's current minimum severity as an OrtLoggingLevel.
ORT_API_STATUS_IMPL(OrtApis::Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger,
                    _Out_ OrtLoggingLevel* out) {
  return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr {
    const auto* actual_logger = reinterpret_cast<const onnxruntime::logging::Logger*>(logger);
    const auto severity = actual_logger->GetSeverity();
    *out = static_cast<OrtLoggingLevel>(severity);
    return nullptr;
  });
}
#if ENABLE_CUSTOM_OP_API
// Invokes `api_body` (a callable returning OrtStatusPtr) inside the API_IMPL_BEGIN/API_IMPL_END guard, which keeps
// exceptions from crossing the C API boundary.
template <typename Body>
static OrtStatusPtr ExecuteIfCustomOpsApiEnabled(const Body& api_body) {
  API_IMPL_BEGIN
  return api_body();
  API_IMPL_END
}
#else
// The custom operator API is compiled out. Callers get ORT_NOT_IMPLEMENTED and their body never runs.
template <typename Body>
static OrtStatusPtr ExecuteIfCustomOpsApiEnabled(const Body&) {
  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Custom operator API is not enabled in this build");
}
#endif
// Reports through `out` how many inputs the shape inference context exposes.
ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context,
                    _Out_ size_t* out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const size_t input_count = context->GetInputCount();
    *out = input_count;
    return nullptr;
  });
}
// Returns the type/shape info of input `index`. *info is always written; when it ends up null the call fails
// with ORT_INVALID_ARGUMENT.
ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context,
                    _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    OrtTensorTypeAndShapeInfo* type_shape = context->GetInputTypeShape(index);
    *info = type_shape;
    if (type_shape == nullptr) {
      return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
                                   "Failed to fetch type shape info for the index.");
    }
    return nullptr;
  });
}
// Looks up the node attribute `attr_name`. *attr is always written; a missing attribute yields
// ORT_INVALID_ARGUMENT.
ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context,
                    _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto* found = reinterpret_cast<const OrtOpAttr*>(context->GetAttr(attr_name));
    *attr = found;
    if (found == nullptr) {
      return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist.");
    }
    return nullptr;
  });
}
// Sets the shape of output `index` from `info`. A failure Status is converted into an OrtStatus that keeps the
// same code and message.
ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context,
                    _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto status = context->SetOutputTypeShape(index, info);
    if (!status.IsOK()) {
      return OrtApis::CreateStatus(static_cast<OrtErrorCode>(status.Code()), status.ErrorMessage().c_str());
    }
    return nullptr;
  });
}
// Copies the value of an operator attribute (e.g. one obtained from ShapeInferContext_GetAttribute) into a
// caller-provided buffer.
//   op_attr:  attribute to read; a null pointer is rejected.
//   type:     selects which AttributeProto field is read and how it is laid out in `data`:
//               FLOAT  -> one float                FLOATS  -> packed floats
//               INT    -> one int64_t              INTS    -> packed int64_t values
//               STRING -> NUL-terminated chars     STRINGS -> each string NUL-terminated, back to back
//   data/len: destination buffer and its size in bytes.
//   out:      for a supported `type`, set to the number of bytes the value needs even when `len` is too small.
//             Callers can therefore query the size first and retry with a large enough buffer.
// Returns nullptr on success. Returns ORT_INVALID_ARGUMENT for a null attribute, a too-small buffer, a missing
// scalar value, or an unsupported `type`.
ORT_API_STATUS_IMPL(OrtApis::ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data,
                    _In_ size_t len, _Out_ size_t* out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    if (!op_attr) {
      return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Invalid attribute.");
    }
    auto attr = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(op_attr);
    OrtStatusPtr ret = nullptr;
    *out = 0;
    switch (type) {
      case OrtOpAttrType::ORT_OP_ATTR_FLOAT: {
        if (len < sizeof(float)) {
          ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
                                      "Size of data not large enough to hold a float.");
        } else {
          if (attr->has_f()) {
            auto output_f = reinterpret_cast<float*>(data);
            *output_f = attr->f();
          } else {
            ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no float value.");
          }
        }
        *out = sizeof(float);
        break;
      }
      case OrtOpAttrType::ORT_OP_ATTR_FLOATS: {
        const auto& floats = attr->floats();
        auto num_floats = floats.size();
        if (len < sizeof(float) * num_floats) {
          ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
                                      "Size of data not large enough to hold the array of floats.");
        } else {
          auto output_f = reinterpret_cast<float*>(data);
          for (auto f : floats) {
            *output_f = f;
            output_f++;
          }
        }
        *out = num_floats * sizeof(float);
        break;
      }
      case OrtOpAttrType::ORT_OP_ATTR_INT: {
        // The value is written as an int64_t, so the buffer must hold sizeof(int64_t) bytes. The previous check
        // against sizeof(int) allowed an out-of-bounds write into 4-7 byte buffers.
        if (len < sizeof(int64_t)) {
          ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
                                      "Size of data not large enough to hold an int64.");
        } else {
          if (attr->has_i()) {
            auto output_i = reinterpret_cast<int64_t*>(data);
            *output_i = attr->i();
          } else {
            ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no int64 value.");
          }
        }
        *out = sizeof(int64_t);
        break;
      }
      case OrtOpAttrType::ORT_OP_ATTR_INTS: {
        const auto& ints = attr->ints();
        auto num_ints = ints.size();
        if (len < sizeof(int64_t) * num_ints) {
          ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
                                      "Size of data not large enough to hold the array of int64.");
        } else {
          auto output_i = reinterpret_cast<int64_t*>(data);
          for (auto i : ints) {
            *output_i = i;
            output_i++;
          }
        }
        *out = num_ints * sizeof(int64_t);
        break;
      }
      case OrtOpAttrType::ORT_OP_ATTR_STRING: {
        const auto& s = attr->s();
        if (len < s.size() + 1) {
          ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
                                      "Size of data not large enough to hold the string.");
        } else {
          char* output_c = reinterpret_cast<char*>(data);
          for (char c : s) {
            *output_c++ = c;
          }
          *output_c = '\0';
        }
        *out = s.size() + 1;
        break;
      }
      case OrtOpAttrType::ORT_OP_ATTR_STRINGS: {
        const auto& ss = attr->strings();
        // Total bytes for all strings, each followed by its NUL terminator.
        size_t num_bytes = 0;
        for (const auto& s : ss) {
          num_bytes += s.size() + 1;
        }
        if (len < num_bytes) {
          ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT,
                                      "Size of data not large enough to hold the array of strings.");
        } else {
          char* output_c = reinterpret_cast<char*>(data);
          for (const auto& s : ss) {
            for (char c : s) {
              *output_c++ = c;
            }
            *output_c++ = '\0';
          }
        }
        *out = num_bytes;
        break;
      }
      default:
        ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type. ");
    }
    return ret;
  });
}
// Reads the float attribute `name` of the kernel's node into *out.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
                    _Out_ float* out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    const auto status = kernel_info->GetAttr<float>(name, out);
    return status.IsOK() ? nullptr : onnxruntime::ToOrtStatus(status);
  });
}
// Reads the int64 attribute `name` of the kernel's node into *out.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
                    _Out_ int64_t* out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    const auto status = kernel_info->GetAttr<int64_t>(name, out);
    return status.IsOK() ? nullptr : onnxruntime::ToOrtStatus(status);
  });
}
// Reads the string attribute `name` using the "query size, then fill" convention:
// - a null `out` only reports the required size (length + NUL) in *size;
// - otherwise the string is copied NUL-terminated when *size is large enough, else ORT_INVALID_ARGUMENT is
//   returned.
// *size always ends up holding the required size when the attribute exists.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name,
                    _Out_ char* out, _Inout_ size_t* size) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    std::string value;
    const auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<std::string>(name, &value);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }
    const size_t required = value.size() + 1;  // includes the terminating NUL
    if (out == nullptr) {
      *size = required;
      return nullptr;
    }
    if (*size < required) {
      *size = required;
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough");
    }
    std::memcpy(out, value.data(), value.size());
    out[value.size()] = '\0';
    *size = required;
    return nullptr;
  });
}
// Copies `values` into a caller-provided array using the C API's "query size, then fill" convention:
// - out == nullptr: only sets *size to the element count, so the caller can size its buffer.
// - otherwise, *size is the capacity of `out` in elements on input:
//   - on success, the values are copied and *size is set to the element count;
//   - when the buffer is too small, *size is set to the required count and INVALID_ARGUMENT is returned.
template <typename T, typename std::enable_if<std::is_fundamental<T>::value, int>::type = 0>
static Status CopyDataFromVectorToMemory(const std::vector<T>& values, T* out, size_t* size) {
  if (out == nullptr) {  // User is querying the true size of the attribute
    *size = values.size();
    return Status::OK();
  }
  if (*size < values.size()) {  // User has provided a buffer that is not large enough
    *size = values.size();
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Result buffer is not large enough");
  }
  // An empty vector may report a null data() pointer. Passing a null pointer to memcpy is undefined behavior
  // even for a zero-byte copy, so skip the copy entirely in that case.
  if (!values.empty()) {
    std::memcpy(out, values.data(), values.size() * sizeof(T));
  }
  *size = values.size();
  return Status::OK();
}
// Reads the float-array attribute `name` into `out`. The size-query and buffer-size rules are those of
// CopyDataFromVectorToMemory.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
                    _Out_ float* out, _Inout_ size_t* size) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    std::vector<float> values;
    const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    const auto status = kernel_info->GetAttrs<float>(name, values);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }
    return onnxruntime::ToOrtStatus(CopyDataFromVectorToMemory<float>(values, out, size));
  });
}
// Reads the int64-array attribute `name` into `out`. The size-query and buffer-size rules are those of
// CopyDataFromVectorToMemory.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
                    _Out_ int64_t* out, _Inout_ size_t* size) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    std::vector<int64_t> values;
    const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    const auto status = kernel_info->GetAttrs<int64_t>(name, values);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }
    return onnxruntime::ToOrtStatus(CopyDataFromVectorToMemory<int64_t>(values, out, size));
  });
}
// Materializes the TensorProto attribute `name` as a new OrtValue whose buffer comes from `allocator`.
// Ownership of the returned OrtValue passes to the caller.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name,
                    _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto* kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);

    onnx::TensorProto tensor_proto;
    onnxruntime::Status status = kernel_info->GetAttr<onnx::TensorProto>(name, &tensor_proto);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }

    // Only the success of this computation matters here: a proto whose byte size cannot be determined is
    // rejected.
    size_t size_in_bytes = 0;
    status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }

    // Create an empty Tensor whose memory is owned through the caller's OrtAllocator, then fill it from the proto.
    const onnxruntime::TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto);
    const auto* elem_type =
        onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType();
    onnxruntime::AllocatorPtr wrapped_allocator =
        std::make_shared<onnxruntime::IAllocatorImplWrappingOrtAllocator>(allocator);
    auto tensor = std::make_unique<onnxruntime::Tensor>(elem_type, shape, std::move(wrapped_allocator));

    status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensor);
    if (!status.IsOK()) {
      return onnxruntime::ToOrtStatus(status);
    }

    // Hand the Tensor over to an OrtValue that deletes it with the Tensor type's deleter.
    auto tensor_type = onnxruntime::DataTypeImpl::GetType<onnxruntime::Tensor>();
    auto ort_value = std::make_unique<OrtValue>();
    ort_value->Init(tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
    *out = ort_value.release();
    return nullptr;
  });
}
// Reports through `out` how many inputs the kernel's node declares.
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_info = *reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    *out = kernel_info.GetInputCount();
    return nullptr;
  });
}
// Reports through `out` how many outputs the kernel's node declares.
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_info = *reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    *out = kernel_info.GetOutputCount();
    return nullptr;
  });
}
// Copies the name of input `index` into `out` (size-query rules of CopyStringToOutputArg).
// An out-of-range index yields ORT_INVALID_ARGUMENT.
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index,
                    _Out_ char* out, _Inout_ size_t* size) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& node = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->node();
    const auto input_defs = node.InputDefs();
    if (index >= input_defs.size()) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds");
    }
    const auto& input_name = input_defs[index]->Name();
    return onnxruntime::ToOrtStatus(CopyStringToOutputArg(
        input_name, "Output buffer is not large enough for ::OrtKernelInfo input name", out, size));
  });
}
// Copies the name of output `index` into `out` (size-query rules of CopyStringToOutputArg).
// An out-of-range index yields ORT_INVALID_ARGUMENT.
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out,
                    _Inout_ size_t* size) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& node = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->node();
    const auto output_defs = node.OutputDefs();
    if (index >= output_defs.size()) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds");
    }
    const auto& output_name = output_defs[index]->Name();
    return onnxruntime::ToOrtStatus(CopyStringToOutputArg(
        output_name, "Output buffer is not large enough for ::OrtKernelInfo output name", out, size));
  });
}
// Builds a new OrtTypeInfo describing input `index`; ownership passes to the caller.
// Errors:
// - ORT_INVALID_ARGUMENT for an out-of-range index;
// - ORT_INVALID_GRAPH when the input carries no type.
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index,
                    _Outptr_ OrtTypeInfo** type_info) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto input_defs = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->node().InputDefs();
    if (index >= input_defs.size()) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds");
    }
    const ONNX_NAMESPACE::TypeProto* input_type = input_defs[index]->TypeAsProto();
    if (!input_type) {
      return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo input does not have a type");
    }
    *type_info = OrtTypeInfo::FromTypeProto(*input_type).release();
    return nullptr;
  });
}
// Builds a new OrtTypeInfo describing output `index`; ownership passes to the caller.
// Errors:
// - ORT_INVALID_ARGUMENT for an out-of-range index;
// - ORT_INVALID_GRAPH when the output carries no type.
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index,
                    _Outptr_ OrtTypeInfo** type_info) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto output_defs = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->node().OutputDefs();
    if (index >= output_defs.size()) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds");
    }
    const ONNX_NAMESPACE::TypeProto* output_type = output_defs[index]->TypeAsProto();
    if (!output_type) {
      return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo output does not have a type");
    }
    *type_info = OrtTypeInfo::FromTypeProto(*output_type).release();
    return nullptr;
  });
}
// Sets *is_constant to whether input `index` is a constant initializer. When it is, *out is set to its value.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, _In_ size_t index,
                    _Out_ int* is_constant, _Outptr_ const OrtValue** out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_info = *reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    const auto found = kernel_info.TryGetConstantInput(gsl::narrow_cast<int>(index), out);
    *is_constant = static_cast<int>(found);
    return nullptr;
  });
}
// Copies the kernel's node name into `out` (size-query rules of CopyStringToOutputArg).
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out,
                    _Inout_ size_t* size) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& node_name = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->node().Name();
    const auto status =
        CopyStringToOutputArg(node_name, "Output buffer is not large enough for ::OrtKernelInfo node name", out, size);
    return onnxruntime::ToOrtStatus(status);
  });
}
// Exposes the logger of the kernel's execution provider as an opaque OrtLogger.
// Fails with ORT_INVALID_GRAPH when there is no execution provider or it has no logger.
ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_info = *reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    const auto* provider = kernel_info.GetExecutionProvider();
    if (!provider) {
      return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo does not have an execution provider");
    }
    const auto* provider_logger = provider->GetLogger();
    if (!provider_logger) {
      return OrtApis::CreateStatus(ORT_INVALID_GRAPH,
                                   "::OrtKernelInfo cannot get a valid logger from "
                                   "its execution provider");
    }
    *logger = reinterpret_cast<const OrtLogger*>(provider_logger);
    return nullptr;
  });
}
// Wraps the kernel's allocator for `mem_type` in a newly allocated OrtAllocator.
// Ownership of the wrapper is handed to the caller. Fails with ORT_INVALID_ARGUMENT if no such allocator exists.
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type,
                    _Outptr_ OrtAllocator** out) {
  return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
    const auto& kernel_info = *reinterpret_cast<const onnxruntime::OpKernelInfo*>(info);
    onnxruntime::AllocatorPtr kernel_allocator = kernel_info.GetAllocator(mem_type);
    if (kernel_allocator == nullptr) {
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
    }
    *out = std::make_unique<onnxruntime::OrtAllocatorImplWrappingIAllocator>(std::move(kernel_allocator)).release();
    return nullptr;
  });
}
// Allocates a `count_or_bytes`-byte scratch buffer from the kernel context's allocator for `mem_info->device`.
// - A request for 0 bytes succeeds with *out == nullptr and allocates nothing.
// - When the kernel has a compute stream, the allocation is made against that stream.
// - Returns ORT_INVALID_ARGUMENT when no allocator exists for the requested device.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context,
                    _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
  // Allocation may throw. The API_IMPL guard converts exceptions into an OrtStatus instead of letting them
  // propagate out of a C API function.
  API_IMPL_BEGIN
  if (count_or_bytes == 0) {
    *out = nullptr;
    return nullptr;
  }
  const auto* kernel_ctx = reinterpret_cast<const onnxruntime::OpKernelContext*>(context);
  onnxruntime::AllocatorPtr allocator = kernel_ctx->GetAllocator(mem_info->device);
  if (!allocator) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
  }
  // The compute stream may be null (KernelContext_GetResource and KernelContext_GetGPUComputeStream check for this
  // too). Only ask a real stream for its wait-notification function.
  onnxruntime::Stream* stream = kernel_ctx->GetComputeStream();
  *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream,
                                   stream != nullptr ? stream->GetWaitNotificationFn() : nullptr);
  return nullptr;
  API_IMPL_END
}
#if ENABLE_CUSTOM_OP_API
#include "core/framework/customregistry.h"
namespace onnxruntime {
// OpKernel adapter for an OrtCustomOp supplied through the C API:
// - construction creates the op's per-node kernel instance via CreateKernel or CreateKernelV2;
// - Compute() runs it via KernelCompute or KernelComputeV2;
// - destruction releases it via KernelDestroy.
struct CustomOpKernel : OpKernel {
  // Throws if the op declares an API version newer than this runtime's ORT_API_VERSION, or if CreateKernelV2
  // reports an error.
  CustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) {
    if (op_.version > ORT_API_VERSION) {
      ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op) + "'");
    }
    // CreateKernelV2 is only consulted for ops declaring a version that includes it, and it may still be null.
    if (op_.version >= min_ort_version_with_compute_v2_support &&
        op_.CreateKernelV2) {
      Ort::ThrowOnError(
          op_.CreateKernelV2(
              &op_,
              OrtGetApiBase()->GetApi(op_.version),
              reinterpret_cast<const OrtKernelInfo*>(&info),
              &op_kernel_));
    } else {
      op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version),
                                    reinterpret_cast<const OrtKernelInfo*>(&info));
    }
  }

  ~CustomOpKernel() override {
    op_.KernelDestroy(op_kernel_);
  }

  // KernelComputeV2's OrtStatus is converted to a Status. The legacy KernelCompute has no error channel, so that
  // path always reports OK.
  Status Compute(OpKernelContext* ctx) const override {
    if (op_.version >= min_ort_version_with_compute_v2_support &&
        op_.KernelComputeV2) {
      auto status_ptr = op_.KernelComputeV2(op_kernel_, reinterpret_cast<OrtKernelContext*>(ctx));
      return ToStatus(status_ptr);
    } else {
      op_.KernelCompute(op_kernel_, reinterpret_cast<OrtKernelContext*>(ctx));
      return Status::OK();
    }
  }

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomOpKernel);
  const OrtCustomOp& op_;     // Not owned; the OrtCustomOp must outlive this kernel.
  void* op_kernel_{nullptr};  // Kernel instance created by the custom op and passed back to its callbacks.
};
#if !defined(ORT_MINIMAL_BUILD)
// Builds the KernelCreateInfo (kernel definition plus factory) that registers `op` under `domain`:
// - the kernel takes the op's name;
// - its opset range comes from the op's optional version callbacks, defaulting to "since opset 1";
// - each input/output i gets an "Input<i>"/"Output<i>" type constraint, where an UNDEFINED element type means
//   any of SUPPORTED_TENSOR_TYPES;
// - the kernel is bound to the op's execution provider, or to the CPU provider when the op names none.
// The factory captures `op` by pointer, so the OrtCustomOp must outlive the registration.
KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCustomOp* op) {
  const size_t num_inputs = op->GetInputTypeCount(op);
  const size_t num_outputs = op->GetOutputTypeCount(op);

  KernelDefBuilder builder;
  builder.SetName(op->GetName(op)).SetDomain(domain);

  // GetStartVersion/GetEndVersion are only consulted for ops declaring a version that includes them.
  const bool may_have_version_callbacks = op->version >= min_ort_version_with_custom_version;
  if (may_have_version_callbacks && op->GetStartVersion && op->GetEndVersion) {
    builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op));
  } else if (may_have_version_callbacks && op->GetStartVersion) {
    builder.SinceVersion(op->GetStartVersion(op));
  } else {
    builder.SinceVersion(1);
  }

  // GetInputMemoryType was introduced in API version 13. Skipping it for older ops lets custom ops compiled
  // against older headers keep working with newer ORT binaries.
  constexpr uint32_t kMinVersionWithInputMemoryType = 13;
  if (op->version >= kMinVersionWithInputMemoryType) {
    for (size_t idx = 0; idx < num_inputs; ++idx) {
      builder.InputMemoryType(op->GetInputMemoryType(op, idx), gsl::narrow_cast<int>(idx));
    }
  }

  // Registers one type constraint. ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED allows every supported tensor type.
  auto add_type_constraint = [&builder](const std::string& constraint_name, ONNXTensorElementDataType elem_type) {
    if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
      builder.TypeConstraint(constraint_name, SUPPORTED_TENSOR_TYPES);
    } else {
      builder.TypeConstraint(constraint_name,
                             DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(elem_type))->AsTensorType());
    }
  };
  for (size_t idx = 0; idx < num_inputs; ++idx) {
    add_type_constraint("Input" + std::to_string(idx), op->GetInputType(op, idx));
  }
  for (size_t idx = 0; idx < num_outputs; ++idx) {
    add_type_constraint("Output" + std::to_string(idx), op->GetOutputType(op, idx));
  }

  const char* provider_type = op->GetExecutionProviderType(op);
  builder.Provider(provider_type != nullptr ? provider_type : onnxruntime::kCpuExecutionProvider);

  KernelCreateFn create_fn = [op](FuncManager&, const OpKernelInfo& kernel_info,
                                  std::unique_ptr<OpKernel>& kernel) -> Status {
    kernel = std::make_unique<CustomOpKernel>(kernel_info, *op);
    return Status::OK();
  };
  return KernelCreateInfo(builder.Build(), create_fn);
}
// Creates the ONNX schema shared by all custom op kernels in `ops`, which are expected to share one operator name.
//
// The formal input/output options (single/optional/variadic), the since-version and the shape inference function
// are taken from the first kernel in `ops`. For every input/output slot, the type constraint is the union of the
// element types declared by all kernels in `ops`, or all supported tensor types if any kernel leaves the slot
// undefined. Compatibility of each individual kernel with the returned schema is checked afterwards by IsCompatible.
//
// @param domain  domain the schema is registered under.
// @param ops     kernels registered under the same operator name; must not be empty.
// @return the created schema.
// Throws (ORT_ENFORCE) if `ops` is empty, if a variadic input/output is not the last one, if another kernel
// declares fewer inputs/outputs than the first one, or if the first kernel has a required output of undefined
// type while not exactly one input slot is dynamically typed.
ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector<const OrtCustomOp*>& ops) {
  // The function registers the first schema assuming all the other ones are the same except the type constraints.
  ORT_ENFORCE(ops.size() > 0, "No kernels to registers.");

  // Number of input/output slots whose type constraint ended up as "all supported tensor types".
  int undefined = 0;

  // Creation of the schema for the first kernel in ops.
  const OrtCustomOp* op = *ops.begin();
  ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0);

  auto create_type_constraint = [&ops, &schema, &undefined](const OrtCustomOp* op, int count, int i, bool is_input) {
    onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single;
    bool is_homogeneous = true;
    int min_arity = 1;

    // The OrtCustomOp interface did not support the methods to query input/output characteristics before
    // ORT API version 8. So, query the relevant methods ONLY from API version 8 onwards.
    if (op->version >= min_ort_version_with_optional_io_support) {
      const auto characteristic = is_input ? op->GetInputCharacteristic(op, i) : op->GetOutputCharacteristic(op, i);

      // Support for optional and variadic inputs/output was added in versions 8 and 14, respectively.
      if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL) {
        option = onnx::OpSchema::FormalParameterOption::Optional;
      } else if ((op->version >= min_ort_version_with_variadic_io_support) &&
                 (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC)) {
        ORT_ENFORCE(i == count - 1, "Only the last ", (is_input ? "input" : "output"),
                    " to a custom op may be marked variadic.");
        option = onnx::OpSchema::FormalParameterOption::Variadic;
        min_arity = is_input ? op->GetVariadicInputMinArity(op) : op->GetVariadicOutputMinArity(op);
        is_homogeneous = static_cast<bool>(is_input
                                               ? op->GetVariadicInputHomogeneity(op)
                                               : op->GetVariadicOutputHomogeneity(op));
      }
    }

    // The loop goes through all operators sharing the same schema to build
    // the minimal type constraints for all of them. All kernels must have
    // the same number of inputs / outputs among themselves to be able to build
    // the type constraints. Any kind of incompatibility between a schema and
    // a kernel is checked by method IsCompatible once the schema is created
    // by this method.
    std::unordered_set<ONNXTensorElementDataType> all_types;
    for (auto o : ops) {
      const size_t o_count = is_input ? o->GetInputTypeCount(o) : o->GetOutputTypeCount(o);
      // Slots are visited in increasing order starting at 0, so this fires as soon as `o` declares fewer
      // inputs/outputs than the first kernel, before GetInputType/GetOutputType is queried out of range.
      ORT_ENFORCE(static_cast<size_t>(i) < o_count,
                  "Another version of operator '", schema.Name(),
                  "' has a different number of ", (is_input ? "inputs" : "outputs"),
                  ". onnxruntime allows the overloading of an operator "
                  "if all versions have the same number of declared ",
                  (is_input ? "inputs" : "outputs"), ".");
      const auto type = is_input ? o->GetInputType(o, i) : o->GetOutputType(o, i);
      if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
        // If 'type' is undefined, all types are allowed regardless of what other versions of the same operator
        // define. In that case, all_types is cleared, that's the convention used by the code following this loop
        // to declare all types as possible types.
        all_types.clear();
        break;
      }
      all_types.insert(type);
    }

    std::string prefix = is_input ? "Input" : "Output";
    std::string name = prefix + std::to_string(i);
    if (is_input) {
      schema.Input(i, name, "", name, option, is_homogeneous, min_arity);
    } else {
      schema.Output(i, name, "", name, option, is_homogeneous, min_arity);
    }

    if (!all_types.empty()) {
      // all_types is not empty then only the types in this container are allowed of this input.
      std::vector<std::string> types;
      for (auto type : all_types) {
        const ONNX_NAMESPACE::TypeProto* type_proto =
            DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(type))->GetTypeProto();
        types.push_back(*ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(*type_proto));
      }
      schema.TypeConstraint(name, types, "defined list of types");
    } else {
      // all_types is empty. As mentioned in the previous loop, all types are allowed.
      schema.TypeConstraint(name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
      undefined++;
    }
  };

  const size_t input_count = op->GetInputTypeCount(op);
  for (size_t i = 0; i < input_count; i++) {
    create_type_constraint(op, static_cast<int>(input_count), static_cast<int>(i), true);
  }

  // create_type_constraint keeps incrementing `undefined` for outputs too; the check below is about dynamically
  // typed *inputs* only, so snapshot the count before any output is processed.
  const int undefined_inputs = undefined;

  const size_t output_count = op->GetOutputTypeCount(op);
  for (size_t i = 0; i < output_count; i++) {
    const auto type = op->GetOutputType(op, i);
    if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
      // GetOutputCharacteristic is not part of OrtCustomOp before optional I/O support was added; outputs of
      // such older ops are implicitly required.
      const bool is_required =
          op->version < min_ort_version_with_optional_io_support ||
          op->GetOutputCharacteristic(op, i) == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
      if (is_required) {
        ORT_ENFORCE(1 == undefined_inputs,
                    "There must be one (and only one) dynamic typed input to the custom op. "
                    "Its type info at runtime will be used to infer the type info of this dynamic typed output "
                    "which is required for the success of the model loading step. "
                    "More than one dynamic typed inputs are currently not supported as differing types at runtime "
                    "means the output type cannot be inferred without which model loading cannot proceed.");
      }
    }
    create_type_constraint(op, static_cast<int>(output_count), static_cast<int>(i), false);
  }

  schema.SetDomain(domain);
  if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) {
    schema.SinceVersion(op->GetStartVersion(op));
  } else {
    schema.SinceVersion(1);
  }
  schema.AllowUncheckedAttributes();

  if (op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn) {
    schema.TypeAndShapeInferenceFunction([op](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
      OrtShapeInferContext ctx(infer_ctx);
      op->InferOutputShapeFn(op, &ctx);
    });
  }
  return schema;
}
// Checks that custom op kernel `op` agrees with `schema`.
//
// Both must declare the same number of inputs and of outputs. Each input/output must have the same formal option
// (single/optional/variadic) in the schema as its characteristic on the op. For variadic ones, the homogeneity
// and the minimum arity must also match. For inputs, the homogeneity check is skipped for ops older than the
// variadic-I/O API version.
// Ops older than the optional-I/O API version do not provide the characteristic methods, so all of their
// inputs/outputs are treated as required (single).
//
// @return OK if compatible, otherwise a failure status describing the first mismatch found.
Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* op) {
  // English ordinal suffix of the 1-based position index + 1: "st" (1st, 21st), "nd", "rd", "th" (4th, 11th-13th).
  auto ordinal_suffix = [](size_t index) -> const char* {
    const size_t position = index + 1;
    const size_t last_two = position % 100;
    if (last_two >= 11 && last_two <= 13) {
      return "th";
    }
    switch (position % 10) {
      case 1:
        return "st";
      case 2:
        return "nd";
      case 3:
        return "rd";
      default:
        return "th";
    }
  };

  // GetInputCharacteristic/GetOutputCharacteristic are not part of OrtCustomOp before optional I/O support.
  const bool has_characteristics = op->version >= min_ort_version_with_optional_io_support;

  const size_t input_count = op->GetInputTypeCount(op);
  const size_t output_count = op->GetOutputTypeCount(op);

  // check inputs
  const auto& input_parameters = schema.inputs();
  ORT_RETURN_IF_NOT(input_parameters.size() == input_count, "input count does not match");
  for (size_t i = 0; i < input_parameters.size(); ++i) {
    const auto characteristic = has_characteristics
                                    ? op->GetInputCharacteristic(op, i)
                                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
    const auto& formal_parameter = input_parameters[i];
    if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL) {
      // Only reachable when has_characteristics is true, i.e. the op supports optional inputs.
      ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Optional,
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " input to be of optional type");
    } else if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC) {
      ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Variadic,
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " input to be of variadic type");
      ORT_RETURN_IF_NOT(op->version < min_ort_version_with_variadic_io_support ||
                            formal_parameter.GetIsHomogeneous() == (op->GetVariadicInputHomogeneity(op) != 0),
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " input to keep same homogeneity");
      ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " input to keep same arity");
    } else {
      ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single,
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " input to be of single type");
    }
  }

  // check outputs
  const auto& output_parameters = schema.outputs();
  ORT_RETURN_IF_NOT(output_parameters.size() == output_count, "output count does not match");
  for (size_t i = 0; i < output_parameters.size(); ++i) {
    const auto characteristic = has_characteristics
                                    ? op->GetOutputCharacteristic(op, i)
                                    : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
    const auto& formal_parameter = output_parameters[i];
    if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL) {
      ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Optional,
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " output to be of optional type");
    } else if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC) {
      ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Variadic,
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " output to be of variadic type");
      ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0),
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " output to keep same homogeneity");
      ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicOutputMinArity(op),
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " output to keep same arity");
    } else {
      ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single,
                        "custom op schemas mismatch, expecting ", i + 1, ordinal_suffix(i),
                        " output to be of single type");
    }
  }
  return Status::OK();
}
void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs,
ONNX_NAMESPACE::InferenceContext& infer_ctx) {
for (const auto& kernel_def : kernel_defs) {
const auto& type_constraints = kernel_def->TypeConstraints();
auto num_inputs = infer_ctx.getNumInputs();
bool matched = true;
ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
// first, make sure there is a constraint for every input
for (size_t i = 0; i < num_inputs && matched; ++i) {
auto input_name = "Input" + std::to_string(i);
auto input_type = infer_ctx.getInputType(i);
if (input_type) {
auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type());
auto tc_iter = type_constraints.find(input_name);
if (tc_iter != type_constraints.end()) {
if (tc_iter->second.size() > 1) {
undef = elem_type;
} else if (tc_iter->second.size() != 1 ||
tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) {
matched = false;
}
} else {
matched = false;
}
} else {
matched = false;
}
} // for
// next, ensure that there is a constraint for every output
auto num_outputs = infer_ctx.getNumOutputs();
for (size_t i = 0; i < num_outputs && matched; i++) {
auto output_name = "Output" + std::to_string(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) {
matched = false;
}
}
if (matched) {
for (size_t i = 0; i < num_outputs; i++) {
auto output_name = "Output" + std::to_string(i);
auto output_type = infer_ctx.getOutputType(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter->second.size() > 1) {
output_type->mutable_tensor_type()->set_elem_type(undef);
} else {
output_type->mutable_tensor_type()->set_elem_type(
tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type());
}
}
break;
}
}
}
#endif
// Builds a CustomRegistry containing the custom ops of every domain in `op_domains`.
//
// Full builds:
//   * Custom ops are grouped by name within each domain.
//   * One ONNX schema is created per name (CreateSchema), and one kernel is registered per OrtCustomOp.
//   * Each kernel is checked against its schema (IsCompatible).
//   * Each schema's inference function is wrapped so that InferOutputTypes runs before any op-provided
//     shape inference.
//   * The schemas are registered as an opset of the domain (baseline version 1, opset version 1000).
//
// Minimal builds: no ONNX schema is created. Only kernels are registered, with one "T<i>" type constraint
// (all supported tensor types) per input whose declared type is undefined.
//
// @param op_domains  custom op domains to register.
// @param output      receives the newly created registry; it is replaced even if an error is returned.
// @return the first error from kernel/opset registration or from the compatibility checks, otherwise OK.
common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domains,
                                    std::shared_ptr<CustomRegistry>& output) {
  output = std::make_shared<CustomRegistry>();

  for (const auto& domain : op_domains) {
#if !defined(ORT_MINIMAL_BUILD)
    std::unordered_map<std::string, ONNX_NAMESPACE::OpSchema> schema_map;
    std::unordered_map<std::string, InlinedVector<const KernelDef*>> kernel_def_map;

    // Domain is not empty - add it to the DomainToVersion ONNX map
    // If domain is empty, it is assumed to be part of the ONNX domain
    if (!domain->domain_.empty()) {
      // Add it to the DomainToVersion ONNX map if it doesn't already exist
      // For example, two sessions using the same session_options should not add the same custom op domain
      // to the version map twice
      auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
      const auto& domain_to_version_map = domain_to_version_range_instance.Map();

      if (domain_to_version_map.find(domain->domain_) == domain_to_version_map.end()) {
        domain_to_version_range_instance.AddDomainToVersion(domain->domain_, 1, 1000);
      }
    }

    // domain_kernels aggregates all custom operators of this domain by name.
    std::unordered_map<std::string, std::vector<const OrtCustomOp*>> domain_kernels;
    for (const auto* op : domain->custom_ops_) {
      domain_kernels[op->GetName(op)].push_back(op);
    }

    // Creation of the schemas, one per unique name.
    for (auto& [name, ops] : domain_kernels) {
      auto schema = CreateSchema(domain->domain_, ops);
      // schema.Name() is equal to ops[0]->GetName(ops[0]) and op->GetName(op) is the value
      // used as a key for dictionary domain_kernels, therefore name == schema.Name().
      schema_map.emplace(schema.Name(), schema);

      // This loop checks that all custom operators sharing the same name are compatible with the defined schema.
      for (const auto* op : ops) {
        // define kernel
        auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
        kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get());
        ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info));
        // If IsCompatible returns an error, then all custom operators named
        // 'op->GetName(op)' are not compatible among themselves.
        // They should have the same number of inputs and outputs, the same characteristics,
        // (optional, ...). Only the type can change.
        ORT_RETURN_IF_ERROR(IsCompatible(schema, op));
      }
    }

    std::vector<ONNX_NAMESPACE::OpSchema> schemas;
    schemas.reserve(schema_map.size());
    // Iterate by reference: schema_map is not used after this loop, so each schema is moved out rather than
    // copied twice (once into the loop variable, once into `schemas`).
    for (auto& [op_name, op_schema] : schema_map) {
      schemas.push_back(std::move(op_schema));
      InlinedVector<const KernelDef*> kernel_defs = std::move(kernel_def_map[op_name]);
      auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction();
      ONNX_NAMESPACE::InferenceFunction extended_infer_fn =
          [infer_fn, kernel_defs = std::move(kernel_defs)](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
            InferOutputTypes(kernel_defs, infer_ctx);
            if (infer_fn) {
              infer_fn(infer_ctx);
            }
          };
      schemas.back().TypeAndShapeInferenceFunction(extended_infer_fn);
    }

    ORT_RETURN_IF_ERROR(output->RegisterOpSet(schemas,
                                              domain->domain_,
                                              1 /* baseline opset version */,
                                              1000 /* opset version */));
#else
    // For a minimal build, we may not need any of the ONNX schema stuff but we still need to track
    // the type template parameters to be used during the kernel def building step below
    for (const auto* op : domain->custom_ops_) {
      size_t undefined = 0;
      size_t input_count = op->GetInputTypeCount(op);
      for (size_t i = 0; i < input_count; i++) {
        auto type = op->GetInputType(op, i);
        if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
          undefined++;
        }
      }

      KernelDefBuilder def_builder;
      def_builder.SetName(op->GetName(op))
          .SetDomain(domain->domain_)
          .SinceVersion(1);

      // GetInputMemoryType was introduced in ver 13. This check allows custom ops compiled using older versions
      // to work with newer versions (> 12) of the ORT binary.
      if (op->version > 12) {
        for (size_t i = 0; i < input_count; i++) {
          def_builder.InputMemoryType(op->GetInputMemoryType(op, i), gsl::narrow_cast<int>(i));
        }
      }

      for (size_t i = 0; i < undefined; i++) {
        def_builder.TypeConstraint("T" + std::to_string(i), SUPPORTED_TENSOR_TYPES);
      }

      if (const char* provider_type = op->GetExecutionProviderType(op)) {
        def_builder.Provider(provider_type);
      } else {
        def_builder.Provider(onnxruntime::kCpuExecutionProvider);
      }

      KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info,
                                             std::unique_ptr<OpKernel>& out) -> Status {
        out = std::make_unique<CustomOpKernel>(info, *op);
        return Status::OK();
      };

      KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn);
      ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(create_info));
    }
#endif
  }  // for each domain

  return Status::OK();
}
} // namespace onnxruntime
#endif // ENABLE_CUSTOM_OP_API