diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h new file mode 100644 index 0000000000..a98307a77b --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -0,0 +1,679 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Summary +// The header has APIs to save custom op authors the trouble of defining schemas, +// which will be inferred from the functions' signatures, as long as their argument list has types supported here. +// Input could be: +// 1. Tensor of onnx data types. +// 2. Span of onnx data types. +// 3. Scalar of onnx data types. +// An input could be optional if indicated as std::optional<...>. +// For an output, it must be a tensor of onnx data types. +// Further, the header also has a utility for a simple custom struct, where resources could be kept, to be registered as a custom op. +// For concrete examples, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". +// Note - all APIs in this header are ABI. 
+ +#pragma once +#include "onnxruntime_cxx_api.h" +#include +#include +#include + +namespace Ort { +namespace Custom { + +class TensorBase { + public: + TensorBase(OrtKernelContext* ctx) : ctx_(ctx) {} + virtual ~TensorBase() = default; // virtual: Tensor objects are owned via TensorPtr and deleted through the base type + operator bool() const { + return shape_.has_value(); + } + + protected: + struct KernelContext ctx_; + std::optional> shape_; +}; + +template +struct Span { + const T* data_ = {}; + size_t size_ = {}; + void Assign(const T* data, size_t size) { + data_ = data; + size_ = size; + } + size_t size() const { return size_; } + T operator[](size_t indice) const { + return data_[indice]; + } +}; + +template +class Tensor : public TensorBase { + public: + using TT = typename std::remove_reference::type; + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + const_value_ = ctx_.GetInput(indice); + auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + } + } + const std::vector& Shape() const { + if (!shape_.has_value()) { + ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return shape_.value(); + } + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); + } else { + return 0; + } + } + const TT* Data() const { + return reinterpret_cast(const_value_.GetTensorRawData()); + } + TT* Allocate(const std::vector& shape) { + shape_ = shape; + if (!data_) { + data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData(); + } + return data_; + } + static TT GetT() { return (TT)0; } + const Span& AsSpan() { + if (!shape_.has_value() || shape_->size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a span out of 
Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + span_.Assign(Data(), static_cast((*shape_)[0])); + return span_; + } + const T& AsScalar() { + if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return *Data(); + } + + private: + size_t indice_; + bool is_input_; + ConstValue const_value_; // for input + TT* data_{}; // for output + Span span_; +}; + +template <> +class Tensor : public TensorBase { + public: + using strings = std::vector; + + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto const_value = ctx_.GetInput(indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + auto num_chars = const_value.GetStringTensorDataLength(); + // note - there will be copy ... 
+ auto num_strings = static_cast(NumberOfElement()); + if (num_strings) { + std::vector chars(num_chars + 1, '\0'); + std::vector offsets(num_strings); + const_value.GetStringTensorContent(static_cast(chars.data()), num_chars, offsets.data(), offsets.size()); + auto upper_bound = num_strings - 1; + input_strings_.resize(num_strings); + for (size_t i = upper_bound;; --i) { + if (i < upper_bound) { + chars[offsets[i + 1]] = '\0'; + } + input_strings_[i] = chars.data() + offsets[i]; + if (0 == i) { + break; + } + } + } + } + } + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); + } else { + return 0; + } + } + const strings& Data() const { + return input_strings_; + } + void SetStringOutput(const strings& ss, const std::vector& dims) { + shape_ = dims; + std::vector raw; + for (const auto& s : ss) { + raw.push_back(s.data()); + } + auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); + // note - there will be copy ... 
+ output.FillStringTensor(raw.data(), raw.size()); + } + const Span& AsSpan() { + ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + const std::string& AsScalar() { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0]; + } + + private: + size_t indice_; + bool is_input_; + std::vector input_strings_; // for input +}; + +template <> +class Tensor : public TensorBase { + public: + using strings = std::vector; + using string_views = std::vector; + + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + if (is_input_) { + if (indice >= ctx_.GetInputCount()) { + ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto const_value = ctx_.GetInput(indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + shape_ = type_shape_info.GetShape(); + auto num_chars = const_value.GetStringTensorDataLength(); + chars_.resize(num_chars + 1, '\0'); + auto num_strings = static_cast(NumberOfElement()); + if (num_strings) { + std::vector offsets(num_strings); + const_value.GetStringTensorContent(static_cast(chars_.data()), num_chars, offsets.data(), offsets.size()); + offsets.push_back(num_chars); + for (size_t i = 0; i < num_strings; ++i) { + input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]); + } + } + } + } + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); + } else { + return 0; + } + } + const string_views& Data() const { + return input_string_views_; + } + void SetStringOutput(const strings& ss, const std::vector& dims) { + shape_ = dims; + std::vector raw; + for (const auto& s : ss) { + 
raw.push_back(s.data()); + } + auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); + // note - there will be copy ... + output.FillStringTensor(raw.data(), raw.size()); + } + const Span& AsSpan() { + ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + std::string_view AsScalar() { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor", + OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0]; + } + + private: + size_t indice_; + bool is_input_; + std::vector chars_; // for input + std::vector input_string_views_; // for input +}; + +using TensorPtr = std::unique_ptr; + +//////////////////////////// OrtCustomOpBase //////////////////////////////// + +struct OrtCustomOpBase : public OrtCustomOp { + using ConstOptionalFloatTensor = std::optional&>; + using OptionalFloatTensor = std::optional>; + + // CreateTuple + template + static typename std::enable_if>::type + CreateTuple(OrtKernelContext*, std::vector&, size_t, size_t, const std::string&) { + return std::make_tuple(); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + std::tuple current = std::tuple{context}; + auto next = CreateTuple(context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#define CREATE_TUPLE_INPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + 
return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t 
num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename 
std::enable_if>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + tensors.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE_OUTPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + tensors.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ + tensors.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { 
\ + if (ith_output < num_output) { \ + tensors.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE(data_type) \ + CREATE_TUPLE_INPUT(data_type) \ + CREATE_TUPLE_OUTPUT(data_type) + + CREATE_TUPLE(bool) + CREATE_TUPLE(float) + CREATE_TUPLE(Ort::Float16_t) + CREATE_TUPLE(Ort::BFloat16_t) + CREATE_TUPLE(double) + CREATE_TUPLE(int8_t) + CREATE_TUPLE(int16_t) + CREATE_TUPLE(int32_t) + CREATE_TUPLE(int64_t) + CREATE_TUPLE(uint8_t) + CREATE_TUPLE(uint16_t) + CREATE_TUPLE(uint32_t) + CREATE_TUPLE(uint64_t) + CREATE_TUPLE(std::string) + CREATE_TUPLE_INPUT(std::string_view) + + // ParseArgs ... + template + static typename std::enable_if<0 == sizeof...(Ts)>::type + ParseArgs(std::vector&, std::vector&) { + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + ParseArgs(input_types, output_types); + } + +#define PARSE_INPUT_BASE(pack_type, onnx_type) \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } + +#define PARSE_INPUT(data_type, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Tensor*, onnx_type) \ + PARSE_INPUT_BASE(const 
Custom::Tensor&, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Span*, onnx_type) \ + PARSE_INPUT_BASE(const Custom::Span&, onnx_type) \ + PARSE_INPUT_BASE(data_type, onnx_type) + +#define PARSE_OUTPUT(data_type, onnx_type) \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same&>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + output_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } + +#define PARSE_ARGS(data_type, onnx_type) \ + PARSE_INPUT(data_type, onnx_type) \ + PARSE_OUTPUT(data_type, onnx_type) + + PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) + PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) + PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) + PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) + PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) + PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) + PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) + PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) + PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) + PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) + PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) + PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) + PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) + PARSE_ARGS(std::string_view, 
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output + + OrtCustomOpBase(const char* op_name, + const char* execution_provider) : op_name_(op_name), + execution_provider_(execution_provider) { + OrtCustomOp::version = ORT_API_VERSION; + + OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; + OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtCustomOpBase*)op)->execution_provider_.c_str(); }; + OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; }; + + OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->input_types_.size(); + }; + + OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice]; + }; + + OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->output_types_.size(); + }; + + OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice]; + }; + + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { + return INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) { + return INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; }; + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; }; + } + + const std::string op_name_; + const std::string execution_provider_; + + std::vector input_types_; + std::vector output_types_; +}; + +//////////////////////////// OrtCustomFunc //////////////////////////////// +// The struct is to implement 
function-as-op. +// E.g. a function might be defined as: +// void Filter(const Ort::Custom::Tensor& floats_in, Ort::Custom::Tensor& floats_out) { ... } +// It could be registered this way: +// Ort::CustomOpDomain v2_domain{"v2"}; +// std::unique_ptr fil_op_ptr{Ort::Custom::CreateCustomOp("Filter", "CPUExecutionProvider", Filter)}; +// v2_domain.Add(fil_op_ptr.get()); +// session_options.Add(v2_domain); +// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". +template +struct OrtCustomFunc : public OrtCustomOpBase { + using ComputeFn = void (*)(Args...); + using MyType = OrtCustomFunc; + + struct Kernel { + size_t num_input_{}; + size_t num_output_{}; + ComputeFn compute_fn_{}; + std::string ep_{}; + }; + + OrtCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFn compute_fn) : OrtCustomOpBase(op_name, execution_provider), + compute_fn_(compute_fn) { + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + std::vector tensors; + auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... 
t_args) { kernel->compute_fn_(t_args...); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + kernel->compute_fn_ = static_cast(this_)->compute_fn_; + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + } + + ComputeFn compute_fn_; +}; // struct OrtCustomFunc + +/////////////////////////// OrtCustomStruct /////////////////////////// +// The struct is to implement struct-as-op. +// E.g. a struct might be defined as: +// struct Merge { +// Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...} +// void Compute(const Ort::Custom::Tensor& strings_in, +// std::string_view string_in, +// Ort::Custom::Tensor* strings_out) {...} +// bool reverse_ = false; +// }; +// It could be registered this way: +// Ort::CustomOpDomain v2_domain{"v2"}; +// std::unique_ptr mrg_op_ptr{Ort::Custom::CreateCustomOp("Merge", "CPUExecutionProvider")}; +// v2_domain.Add(mrg_op_ptr.get()); +// session_options.Add(v2_domain); +// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". 
+template +struct OrtCustomStruct : public OrtCustomOpBase { + template + using CustomComputeFn = void (CustomOp::*)(Args...); + using MyType = OrtCustomStruct; + + struct Kernel { + size_t num_input_{}; + size_t num_output_{}; + std::unique_ptr custom_op_; + std::string ep_{}; + }; + + OrtCustomStruct(const char* op_name, + const char* execution_provider) : OrtCustomOpBase(op_name, + execution_provider) { + init(&CustomOp::Compute); + } + + template + void init(CustomComputeFn) { + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + std::vector tensors; + auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + kernel->custom_op_ = std::make_unique(ort_api, info); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + } +}; // struct OrtCustomStruct + +/////////////////////////// CreateCustomOp //////////////////////////// + +template +OrtCustomOp* CreateCustomOp(const char* op_name, + const char* execution_provider, + void (*custom_compute_fn)(Args...)) { + using OrtCustomOpTPtr = OrtCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn).release(); +} + +template +OrtCustomOp* CreateCustomOp(const char* op_name, + const char* execution_provider) { + using OrtCustomOpTPtr = 
OrtCustomStruct; + return std::make_unique(op_name, execution_provider).release(); +} + +} // namespace Custom +} // namespace Ort diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 6cd4044b19..2652cf5e2e 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -571,20 +571,20 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? "nd" : "th"), " input to be of variadic type"); + ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicInputHomogeneity(op) != 0), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " input to keep same homogeneity"); + ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " input to keep same arity"); } else { ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single, "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? "nd" : "th"), " input to be of single type"); } - ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? "nd" : "th"), - " input to keep same homogeneity"); - ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? "nd" : "th"), - " input to keep same arity"); } // check outputs const auto& output_parameters = schema.outputs(); @@ -602,20 +602,20 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? 
"nd" : "th"), " output to be of variadic type"); + ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " output to keep same homogeneity"); + ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicOutputMinArity(op), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " output to keep same arity"); } else { ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single, "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? "nd" : "th"), " output to be of single type"); } - ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? "nd" : "th"), - " output to keep same homogeneity"); - ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? 
"nd" : "th"), - " output to keep same arity"); } return Status::OK(); } diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index af1df0ed63..218a354000 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -18,6 +18,7 @@ #include "core/graph/constants.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_lite_custom_op.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/util/thread_utils.h" @@ -2827,6 +2828,197 @@ TEST(CApiTest, TestMultiStreamInferenceSimpleSSD) { } #endif +TEST(LiteCustomOpTest, CustomFunc) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(0); +#if defined(_WIN32) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); +#elif defined(__APPLE__) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); +#else + session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); +#endif + + Ort::Session session{*ort_env, TSTR("testdata/fuse_select_filter.onnx"), session_options}; + + const char* input_names[] = {"vector_1", "vector_2", "alpha", "indices"}; + const char* output_names[] = {"vector_filtered"}; + + float vector_1_value[] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}; + int64_t vector_1_dim[] = {10}; + + float vector_2_value[] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}; + int64_t vector_2_dim[] = {6}; + + int32_t alpha_value[] = {2}; + int64_t alpha_dim[] = {1}; + + int32_t indices_value[] = {0, 1, 2, 3, 4, 5}; + int64_t indices_dim[] = {6}; + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + 
Ort::Value input_tensors[] = { + Ort::Value::CreateTensor(memory_info, vector_1_value, 10, vector_1_dim, 1), + Ort::Value::CreateTensor(memory_info, vector_2_value, 6, vector_2_dim, 1), + Ort::Value::CreateTensor(memory_info, alpha_value, 1, alpha_dim, 1), + Ort::Value::CreateTensor(memory_info, indices_value, 6, indices_dim, 1)}; + + Ort::RunOptions run_options; + auto output_tensors = session.Run(run_options, input_names, input_tensors, 4, output_names, 1); + const auto& vector_filterred = output_tensors.at(0); + auto type_shape_info = vector_filterred.GetTensorTypeAndShapeInfo(); + const float* floats_output = static_cast(vector_filterred.GetTensorRawData()); + ASSERT_TRUE(floats_output[0] == 8); + ASSERT_TRUE(floats_output[1] == 16); +} + +struct Merge { + Merge(const OrtApi* ort_api, const OrtKernelInfo* info) { + int64_t reverse; + ORT_ENFORCE(ort_api->KernelInfoGetAttribute_int64(info, "reverse", &reverse) == nullptr); + reverse_ = reverse != 0; + } + void Compute(const Ort::Custom::Tensor& strings_in, + std::string_view string_in, + Ort::Custom::Tensor* strings_out) { + std::vector string_pool; + for (const auto& s : strings_in.Data()) { + string_pool.emplace_back(s.data(), s.size()); + } + string_pool.emplace_back(string_in.data(), string_in.size()); + if (reverse_) { + for (auto& str : string_pool) { + std::reverse(str.begin(), str.end()); + } + std::reverse(string_pool.begin(), string_pool.end()); + } + strings_out->SetStringOutput(string_pool, {static_cast(string_pool.size())}); + } + bool reverse_ = false; +}; + +TEST(LiteCustomOpTest, CustomStruct) { + const auto& ortApi = Ort::GetApi(); + + Ort::CustomOpDomain v2_domain{"v2"}; + std::unique_ptr mrg_op_ptr{Ort::Custom::CreateCustomOp("Merge", "CPUExecutionProvider")}; + v2_domain.Add(mrg_op_ptr.get()); + + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + 
session_options.Add(v2_domain); + session_options.SetLogSeverityLevel(0); + + Ort::Session session{*ort_env, TSTR("testdata/merge.onnx"), session_options}; + + const char* input_names[] = {"str_in_1", "str_in_2"}; + const char* output_names[] = {"str_out"}; + + OrtAllocator* allocator = nullptr; + ASSERT_TRUE(!ortApi.GetAllocatorWithDefaultOptions(&allocator)); + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + int64_t str_1_dims[] = {2}; + int64_t str_2_dims[] = {1}; + + Ort::Value input_tensors[] = {Ort::Value::CreateTensor(allocator, str_1_dims, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING), + Ort::Value::CreateTensor(allocator, str_2_dims, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)}; + + const char* str_1_raw[] = {"abc", "de"}; + const char* str_2_raw[] = {"fg"}; + + input_tensors[0].FillStringTensor(str_1_raw, 2); + input_tensors[1].FillStringTensor(str_2_raw, 1); + + Ort::RunOptions run_options; + auto output_tensors = session.Run(run_options, input_names, input_tensors, 2, output_names, 1); + const auto& str_out_tensor = output_tensors.at(0); + auto num_chars = str_out_tensor.GetStringTensorDataLength(); + std::vector chars(num_chars + 1, '\0'); + std::vector offsets(3); + str_out_tensor.GetStringTensorContent(static_cast(chars.data()), num_chars, offsets.data(), offsets.size()); + ASSERT_TRUE(strncmp(chars.data(), "gfedcba", 7) == 0); +} + +TEST(LiteCustomOpTest, MissingOptional) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(0); +#if defined(_WIN32) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); +#elif defined(__APPLE__) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); +#else + session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); +#endif + + Ort::Session 
session(*ort_env, TSTR("testdata/optional_2.onnx"), session_options); + + const char* input_names[] = {"float_in_1", "float_in_2"}; + const char* output_names[] = {"float_out_1"}; + + float vector_1_value[] = {0.f, 1.f, 2.f}; + int64_t vector_1_dim[] = {3}; + + float vector_2_value[] = {4.f, 5.f, 6.f, 7.f}; + int64_t vector_2_dim[] = {4}; + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + Ort::Value input_tensors[] = { + Ort::Value::CreateTensor(memory_info, vector_1_value, vector_1_dim[0], vector_1_dim, 1), + Ort::Value::CreateTensor(memory_info, vector_2_value, vector_2_dim[0], vector_2_dim, 1)}; + + Ort::RunOptions run_options; + auto output_tensors = session.Run(run_options, input_names, input_tensors, 2, output_names, 1); + ASSERT_TRUE(output_tensors.size() == 1); +} + +TEST(LiteCustomOpTest, HasOptional) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(0); +#if defined(_WIN32) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); +#elif defined(__APPLE__) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); +#else + session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); +#endif + + Ort::Session session(*ort_env, TSTR("testdata/optional_3.onnx"), session_options); + + const char* input_names[] = {"float_in_1", "float_in_2", "float_in_3"}; + const char* output_names[] = {"float_out_1", "float_out_2"}; + + float vector_1_value[] = {0.f, 1.f, 2.f}; + int64_t vector_1_dim[] = {3}; + + float vector_2_value[] = {4.f, 5.f, 6.f, 7.f}; + int64_t vector_2_dim[] = {4}; + + float vector_3_value[] = {8.f, 9.f}; + int64_t vector_3_dim[] = {2}; + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + Ort::Value input_tensors[] = { + 
Ort::Value::CreateTensor(memory_info, vector_1_value, vector_1_dim[0], vector_1_dim, 1), + Ort::Value::CreateTensor(memory_info, vector_2_value, vector_2_dim[0], vector_2_dim, 1), + Ort::Value::CreateTensor(memory_info, vector_3_value, vector_3_dim[0], vector_3_dim, 1), + }; + + Ort::RunOptions run_options; + auto output_tensors = session.Run(run_options, input_names, input_tensors, 3, output_names, 2); + ASSERT_TRUE(output_tensors.size() == 2); +} + #if !defined(ORT_MINIMAL_BUILD) TEST(MultiKernelSingleSchemaTest, valid) { Ort::SessionOptions session_options; diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index fbc520cefa..843e5a9d86 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -3,6 +3,7 @@ #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT +#include "onnxruntime_lite_custom_op.h" #include #include @@ -47,28 +48,7 @@ struct KernelOne { } }; -struct KernelTwo { - void Compute(OrtKernelContext* context) { - // Setup inputs - Ort::KernelContext ctx(context); - auto input_X = ctx.GetInput(0); - const float* X = input_X.GetTensorData(); - - // Setup output - auto dimensions = input_X.GetTensorTypeAndShapeInfo().GetShape(); - - auto output = ctx.GetOutput(0, dimensions); - int32_t* out = output.GetTensorMutableData(); - - const size_t size = output.GetTensorTypeAndShapeInfo().GetElementCount(); - - // Do computation - for (size_t i = 0; i < size; i++) { - out[i] = static_cast(round(X[i])); - } - } -}; - +// legacy custom op registration struct CustomOpOne : Ort::CustomOpBase { void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const { return std::make_unique().release(); @@ -87,78 +67,97 @@ struct CustomOpOne : Ort::CustomOpBase { ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return 
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; }; -struct CustomOpTwo : Ort::CustomOpBase { - void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const { - return std::make_unique().release(); - }; - - const char* GetName() const { return "CustomOpTwo"; }; - - size_t GetInputTypeCount() const { return 1; }; - ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; - - size_t GetOutputTypeCount() const { return 1; }; - ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }; -}; - -//////////////////////////////////////////////// - -template -T MulTopCompute(const T& input_0, const T& input_1) { - return input_0 * input_1; +// lite custom op as a function +void KernelTwo(const Ort::Custom::Tensor& X, + Ort::Custom::Tensor& Y) { + const auto& shape = X.Shape(); + auto X_raw = X.Data(); + auto Y_raw = Y.Allocate(shape); + auto total = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + for (int64_t i = 0; i < total; i++) { + Y_raw[i] = static_cast(round(X_raw[i])); + } } -struct MulTopKernelFloat { - MulTopKernelFloat(const OrtKernelInfo*){}; - ~MulTopKernelFloat() = default; - void Compute(OrtKernelContext* context) { - Ort::KernelContext ctx(context); - auto tensor_in = ctx.GetInput(0); - const float* float_in = tensor_in.GetTensorData(); - int64_t output_shape = 1; - auto tensor_out = ctx.GetOutput(0, &output_shape, 1); - auto float_out = tensor_out.GetTensorMutableData(); - *float_out = MulTopCompute(float_in[0], float_in[1]); +template +void MulTop(const Ort::Custom::Span& in, Ort::Custom::Tensor& out) { + out.Allocate({1})[0] = in[0] * in[1]; +} + +void Fuse( + OrtKernelContext*, + const Ort::Custom::Span& vector_1, + const Ort::Custom::Span& vector_2, + int32_t alpha, + Ort::Custom::Tensor& vector_output) { + auto len_output = std::min(vector_1.size(), vector_2.size()); + float* floats_out = 
static_cast(vector_output.Allocate({(int64_t)len_output})); + for (size_t i = 0; i < len_output; ++i) { + floats_out[i] = (vector_1[i] + vector_2[i]) * alpha; } -}; +} -struct MulTopOpFloat : Ort::CustomOpBase { - void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelFloat(info); } - const char* GetName() const { return "MulTop"; } - const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } - size_t GetInputTypeCount() const { return 1; } - ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } - size_t GetOutputTypeCount() const { return 1; } - ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } -}; - -//////////////////////////////////////////////// - -struct MulTopKernelInt32 { - MulTopKernelInt32(const OrtKernelInfo*){}; - ~MulTopKernelInt32() = default; - void Compute(OrtKernelContext* context) { - Ort::KernelContext ctx(context); - auto tensor_in = ctx.GetInput(0); - const int32_t* int_in = tensor_in.GetTensorData(); - int64_t output_shape = 1; - auto tensor_out = ctx.GetOutput(0, &output_shape, 1); - auto int_out = tensor_out.GetTensorMutableData(); - *int_out = MulTopCompute(int_in[0], int_in[1]); +void Select(const Ort::Custom::Span& indices_in, + Ort::Custom::Tensor& indices_out) { + std::vector selected_indices; + for (size_t i = 0; i < indices_in.size(); ++i) { + if (indices_in[i] % 2 == 0) { + selected_indices.push_back(indices_in[i]); + } } -}; -struct MulTopOpInt32 : Ort::CustomOpBase { - void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelInt32(info); } - const char* GetName() const { return "MulTop"; } - const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } - size_t GetInputTypeCount() const { return 1; } - ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } - size_t 
GetOutputTypeCount() const { return 1; } - ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } -}; + int32_t* int_out = static_cast(indices_out.Allocate({static_cast(selected_indices.size())})); + for (size_t j = 0; j < selected_indices.size(); ++j) { + int_out[j] = selected_indices[j]; + } +} -//////////////////////////////////////////////// +void Filter(const Ort::Custom::Tensor& floats_in, + Ort::Custom::Tensor& floats_out) { + const float* in = floats_in.Data(); + auto in_len = floats_in.NumberOfElement(); + + std::vector filter_floats; + for (int64_t i = 0; i < in_len; ++i) { + if (in[i] > 1.f) { + filter_floats.push_back(in[i]); + } + } + + float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); + for (size_t j = 0; j < filter_floats.size(); ++j) { + out[j] = filter_floats[j]; + } +} + +void Box(const Ort::Custom::Tensor* float_in_1, + const Ort::Custom::Tensor* float_in_2, + std::optional*> float_in_3, + Ort::Custom::Tensor* float_out_1, + std::optional*> float_out_2) { + auto raw_in_1 = float_in_1->Data(); + auto raw_in_2 = float_in_2->Data(); + + auto l_in_1 = float_in_1->Shape()[0]; + auto l_in_2 = float_in_2->Shape()[0]; + auto l_out_1 = l_in_1 + l_in_2; + + auto raw_out_1 = float_out_1->Allocate({l_out_1}); + + for (int64_t i = 0; i < l_out_1; ++i) { + raw_out_1[i] = i < l_in_1 ? raw_in_1[i] : raw_in_2[i - l_in_1]; + } + + if (float_in_3.has_value() && float_out_2.has_value()) { + auto raw_in_3 = float_in_3.value()->Data(); + auto l_in_3 = float_in_3.value()->Shape()[0]; + auto l_out_2 = l_in_2 + l_in_3; + auto raw_out_2 = float_out_2.value()->Allocate({l_out_2}); + for (int64_t i = 0; i < l_out_2; ++i) { + raw_out_2[i] = i < l_in_2 ? 
raw_in_2[i] : raw_in_3[i - l_in_2]; + } + } +} static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { static std::vector ort_custom_op_domain_container; @@ -171,21 +170,30 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA Ort::Global::api_ = api->GetApi(ORT_API_VERSION); static const CustomOpOne c_CustomOpOne; - static const CustomOpTwo c_CustomOpTwo; + static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; - static const MulTopOpFloat c_MulTopOpFloat; - static const MulTopOpInt32 c_MulTopOpInt32; + static const std::unique_ptr c_MulTopOpFloat{Ort::Custom::CreateCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; + static const std::unique_ptr c_MulTopOpInt32{Ort::Custom::CreateCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; + + static const std::unique_ptr fus_op_ptr{Ort::Custom::CreateCustomOp("Fuse", "CPUExecutionProvider", Fuse)}; + static const std::unique_ptr sel_op_ptr{Ort::Custom::CreateCustomOp("Select", "CPUExecutionProvider", Select)}; + static const std::unique_ptr fil_op_ptr{Ort::Custom::CreateCustomOp("Filter", "CPUExecutionProvider", Filter)}; + static const std::unique_ptr box_op_ptr{Ort::Custom::CreateCustomOp("Box", "CPUExecutionProvider", Box)}; OrtStatus* result = nullptr; ORT_TRY { Ort::CustomOpDomain domain{c_OpDomain}; domain.Add(&c_CustomOpOne); - domain.Add(&c_CustomOpTwo); + domain.Add(c_CustomOpTwo.get()); Ort::CustomOpDomain domain_v2{"v2"}; - domain_v2.Add(&c_MulTopOpFloat); - domain_v2.Add(&c_MulTopOpInt32); + domain_v2.Add(c_MulTopOpFloat.get()); + domain_v2.Add(c_MulTopOpInt32.get()); + domain_v2.Add(fus_op_ptr.get()); + domain_v2.Add(sel_op_ptr.get()); + domain_v2.Add(fil_op_ptr.get()); + domain_v2.Add(box_op_ptr.get()); Ort::UnownedSessionOptions session_options(options); session_options.Add(domain); diff --git a/onnxruntime/test/testdata/fuse_select_filter.onnx 
b/onnxruntime/test/testdata/fuse_select_filter.onnx new file mode 100644 index 0000000000..15d7dd6478 --- /dev/null +++ b/onnxruntime/test/testdata/fuse_select_filter.onnx @@ -0,0 +1,28 @@ +:Ä +P +vector_1 +vector_2 +alpha vector_fused fuse_node"Fuse* + fuse_algo :v2 +4 +indicesindices_selected select_node"Select:v2 +N + vector_fused +indices_selectedvector_gathered gather_node"GatherElements +; +vector_gatheredvector_filtered filter_node"Filter:v2graphZ +vector_1 + + ÿÿÿÿÿÿÿÿÿZ +vector_2 + + ÿÿÿÿÿÿÿÿÿZ +alpha + + ÿÿÿÿÿÿÿÿÿZ +indices + + ÿÿÿÿÿÿÿÿÿb& +vector_filtered + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/merge.onnx b/onnxruntime/test/testdata/merge.onnx new file mode 100644 index 0000000000..1ae7a86d1f --- /dev/null +++ b/onnxruntime/test/testdata/merge.onnx @@ -0,0 +1,15 @@ +:¯ +D +str_in_1 +str_in_2str_out +merge_node"Merge* +reverse :v2graphZ +str_in_1 + + ÿÿÿÿÿÿÿÿÿZ +str_in_2 + + ÿÿÿÿÿÿÿÿÿb +str_out + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/optional_2.onnx b/onnxruntime/test/testdata/optional_2.onnx new file mode 100644 index 0000000000..5c39419eda --- /dev/null +++ b/onnxruntime/test/testdata/optional_2.onnx @@ -0,0 +1,17 @@ +:« +8 + +float_in_1 + +float_in_2 float_out_1box_node"Box:v2graphZ! + +float_in_1 + + ÿÿÿÿÿÿÿÿÿZ! + +float_in_2 + + ÿÿÿÿÿÿÿÿÿb" + float_out_1 + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/optional_3.onnx b/onnxruntime/test/testdata/optional_3.onnx new file mode 100644 index 0000000000..5aa13948bb --- /dev/null +++ b/onnxruntime/test/testdata/optional_3.onnx @@ -0,0 +1,26 @@ +:‹ +Q + +float_in_1 + +float_in_2 + +float_in_3 float_out_1 float_out_2box_node"Box:v2graphZ! + +float_in_1 + + ÿÿÿÿÿÿÿÿÿZ! + +float_in_2 + + ÿÿÿÿÿÿÿÿÿZ! + +float_in_3 + + ÿÿÿÿÿÿÿÿÿb" + float_out_1 + + ÿÿÿÿÿÿÿÿÿb" + float_out_2 + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file