From 034698cf6a3f9a7365025ddbe6a8b5fc17faa682 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 2 May 2023 01:10:10 -0700 Subject: [PATCH] Revert "Implement lite custom op API (#15590)" (#15768) This reverts commit cdf4fc49fc0a8dfe3f2962c78ebf4b513911ca68 because it breaks the "debug_node_input_output" build in "Post Merge" pipeline --- .../core/session/onnxruntime_lite_custom_op.h | 679 ------------------ onnxruntime/core/session/custom_ops.cc | 32 +- onnxruntime/test/shared_lib/test_inference.cc | 192 ----- .../custom_op_library/custom_op_library.cc | 190 +++-- .../test/testdata/fuse_select_filter.onnx | 28 - onnxruntime/test/testdata/merge.onnx | 15 - onnxruntime/test/testdata/optional_2.onnx | 17 - onnxruntime/test/testdata/optional_3.onnx | 26 - 8 files changed, 107 insertions(+), 1072 deletions(-) delete mode 100644 include/onnxruntime/core/session/onnxruntime_lite_custom_op.h delete mode 100644 onnxruntime/test/testdata/fuse_select_filter.onnx delete mode 100644 onnxruntime/test/testdata/merge.onnx delete mode 100644 onnxruntime/test/testdata/optional_2.onnx delete mode 100644 onnxruntime/test/testdata/optional_3.onnx diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h deleted file mode 100644 index a98307a77b..0000000000 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ /dev/null @@ -1,679 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Summary -// The header has APIs to save custom op authors the trouble of defining schemas, -// which will be inferred by functions' signature, as long as their argument list has types supported here. -// Input could be: -// 1. Tensor of onnx data types. -// 2. Span of onnx data types. -// 3. Scalar of onnx data types. -// A input could be optional if indicated as std::optional<...>. -// For an output, it must be a tensor of onnx data types. -// Further, the header also has utility for a simple custom struct, where resources could be kept, to be registered as a custom op. -// For concrete examples, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". -// Note - all APIs in this header are ABI. - -#pragma once -#include "onnxruntime_cxx_api.h" -#include -#include -#include - -namespace Ort { -namespace Custom { - -class TensorBase { - public: - TensorBase(OrtKernelContext* ctx) : ctx_(ctx) {} - operator bool() const { - return shape_.has_value(); - } - - protected: - struct KernelContext ctx_; - std::optional> shape_; -}; - -template -struct Span { - const T* data_ = {}; - size_t size_ = {}; - void Assign(const T* data, size_t size) { - data_ = data; - size_ = size; - } - size_t size() const { return size_; } - T operator[](size_t indice) const { - return data_[indice]; - } -}; - -template -class Tensor : public TensorBase { - public: - using TT = typename std::remove_reference::type; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { - if (is_input_) { - if (indice >= ctx_.GetInputCount()) { - ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); - } - const_value_ = ctx_.GetInput(indice); - auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo(); - shape_ = type_shape_info.GetShape(); - } - } - const std::vector& Shape() const { - if (!shape_.has_value()) { - ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - return shape_.value(); - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); - } else { - return 0; - } - } - const TT* Data() const { - return reinterpret_cast(const_value_.GetTensorRawData()); - } - TT* Allocate(const std::vector& shape) { - shape_ = shape; - if (!data_) { - shape_ = shape; - data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData(); - } - return data_; - } - static TT GetT() { return (TT)0; } - const Span& AsSpan() { - if (!shape_.has_value() || shape_->size() != 1) { - ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor", - OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - span_.Assign(Data(), static_cast((*shape_)[0])); - return span_; - } - const T& AsScalar() { - if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) { - ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor", - OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - return *Data(); - } - - private: - size_t indice_; - bool is_input_; - ConstValue const_value_; // for input - TT* data_{}; // for output - Span span_; -}; - -template <> -class Tensor : public TensorBase { - public: - using strings = std::vector; - - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { - if (is_input_) { - if (indice >= ctx_.GetInputCount()) { - ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); - } - auto const_value = ctx_.GetInput(indice); - auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); - shape_ = type_shape_info.GetShape(); - auto num_chars = const_value.GetStringTensorDataLength(); - // note - there will be copy ... - auto num_strings = static_cast(NumberOfElement()); - if (num_strings) { - std::vector chars(num_chars + 1, '\0'); - std::vector offsets(num_strings); - const_value.GetStringTensorContent(static_cast(chars.data()), num_chars, offsets.data(), offsets.size()); - auto upper_bound = num_strings - 1; - input_strings_.resize(num_strings); - for (size_t i = upper_bound;; --i) { - if (i < upper_bound) { - chars[offsets[i + 1]] = '\0'; - } - input_strings_[i] = chars.data() + offsets[i]; - if (0 == i) { - break; - } - } - } - } - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } - const strings& Data() const { - return input_strings_; - } - void SetStringOutput(const strings& ss, const std::vector& dims) { - shape_ = dims; - std::vector raw; - for (const auto& s : ss) { - raw.push_back(s.data()); - } - auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); - // note - there will be copy ... - output.FillStringTensor(raw.data(), raw.size()); - } - const Span& AsSpan() { - ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - const std::string& AsScalar() { - if (input_strings_.size() != 1) { - ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor", - OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - return input_strings_[0]; - } - - private: - size_t indice_; - bool is_input_; - std::vector input_strings_; // for input -}; - -template <> -class Tensor : public TensorBase { - public: - using strings = std::vector; - using string_views = std::vector; - - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { - if (is_input_) { - if (indice >= ctx_.GetInputCount()) { - ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); - } - auto const_value = ctx_.GetInput(indice); - auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); - shape_ = type_shape_info.GetShape(); - auto num_chars = const_value.GetStringTensorDataLength(); - chars_.resize(num_chars + 1, '\0'); - auto num_strings = static_cast(NumberOfElement()); - if (num_strings) { - std::vector offsets(num_strings); - const_value.GetStringTensorContent(static_cast(chars_.data()), num_chars, offsets.data(), offsets.size()); - offsets.push_back(num_chars); - for (size_t i = 0; i < num_strings; ++i) { - input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]); - } - } - } - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } - const string_views& Data() const { - return input_string_views_; - } - void SetStringOutput(const strings& ss, const std::vector& dims) { - shape_ = dims; - std::vector raw; - for (const auto& s : ss) { - raw.push_back(s.data()); - } - auto output = ctx_.GetOutput(indice_, dims.data(), dims.size()); - // note - there will be copy ... - output.FillStringTensor(raw.data(), raw.size()); - } - const Span& AsSpan() { - ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - std::string_view AsScalar() { - if (input_string_views_.size() != 1) { - ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor", - OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - return input_string_views_[0]; - } - - private: - size_t indice_; - bool is_input_; - std::vector chars_; // for input - std::vector input_string_views_; // for input -}; - -using TensorPtr = std::unique_ptr; - -//////////////////////////// OrtCustomOpBase //////////////////////////////// - -struct OrtCustomOpBase : public OrtCustomOp { - using ConstOptionalFloatTensor = std::optional&>; - using OptionalFloatTensor = std::optional>; - - // CreateTuple - template - static typename std::enable_if>::type - CreateTuple(OrtKernelContext*, std::vector&, size_t, size_t, const std::string&) { - return std::make_tuple(); - } - - template - static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { - std::tuple current = std::tuple{context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); - return std::tuple_cat(current, next); - } - -#define CREATE_TUPLE_INPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } -#define CREATE_TUPLE_OUTPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_output < num_output) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } -#define CREATE_TUPLE(data_type) \ - CREATE_TUPLE_INPUT(data_type) \ - CREATE_TUPLE_OUTPUT(data_type) - - CREATE_TUPLE(bool) - CREATE_TUPLE(float) - CREATE_TUPLE(Ort::Float16_t) - CREATE_TUPLE(Ort::BFloat16_t) - CREATE_TUPLE(double) - CREATE_TUPLE(int8_t) - CREATE_TUPLE(int16_t) - CREATE_TUPLE(int32_t) - CREATE_TUPLE(int64_t) - CREATE_TUPLE(uint8_t) - CREATE_TUPLE(uint16_t) - CREATE_TUPLE(uint32_t) - CREATE_TUPLE(uint64_t) - CREATE_TUPLE(std::string) - CREATE_TUPLE_INPUT(std::string_view) - - // ParseArgs ... - template - static typename std::enable_if<0 == sizeof...(Ts)>::type - ParseArgs(std::vector&, std::vector&) { - } - - template - static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type - ParseArgs(std::vector& input_types, std::vector& output_types) { - ParseArgs(input_types, output_types); - } - -#define PARSE_INPUT_BASE(pack_type, onnx_type) \ - template \ - static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ - ParseArgs(std::vector& input_types, std::vector& output_types) { \ - input_types.push_back(onnx_type); \ - ParseArgs(input_types, output_types); \ - } \ - template \ - static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ - ParseArgs(std::vector& input_types, std::vector& output_types) { \ - input_types.push_back(onnx_type); \ - ParseArgs(input_types, output_types); \ - } - -#define PARSE_INPUT(data_type, onnx_type) \ - PARSE_INPUT_BASE(const Custom::Tensor*, onnx_type) \ - PARSE_INPUT_BASE(const Custom::Tensor&, onnx_type) \ - PARSE_INPUT_BASE(const Custom::Span*, onnx_type) \ - PARSE_INPUT_BASE(const Custom::Span&, onnx_type) \ - PARSE_INPUT_BASE(data_type, onnx_type) - -#define PARSE_OUTPUT(data_type, onnx_type) \ - template \ - static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>::value>::type \ - ParseArgs(std::vector& input_types, std::vector& output_types) { \ - output_types.push_back(onnx_type); \ - ParseArgs(input_types, output_types); \ - } \ - template \ - static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same&>::value>::type \ - ParseArgs(std::vector& input_types, std::vector& output_types) { \ - output_types.push_back(onnx_type); \ - ParseArgs(input_types, output_types); \ - } \ - template \ - static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same*>>::value>::type \ - ParseArgs(std::vector& input_types, std::vector& output_types) { \ - output_types.push_back(onnx_type); \ - ParseArgs(input_types, output_types); \ - } - -#define PARSE_ARGS(data_type, onnx_type) \ - PARSE_INPUT(data_type, onnx_type) \ - PARSE_OUTPUT(data_type, onnx_type) - - PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) - PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) - PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) - PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) - PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) - PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) - PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) - PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) - PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) - PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) - PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) - PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) - PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) - PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) - PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output - - OrtCustomOpBase(const char* op_name, - const char* execution_provider) : op_name_(op_name), - execution_provider_(execution_provider) { - OrtCustomOp::version = ORT_API_VERSION; - - OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; - OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtCustomOpBase*)op)->execution_provider_.c_str(); }; - OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; }; - - OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) { - auto self = reinterpret_cast(op); - return self->input_types_.size(); - }; - - OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) { - auto self = reinterpret_cast(op); - return self->input_types_[indice]; - }; - - OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) { - auto self = reinterpret_cast(op); - return self->output_types_.size(); - }; - - OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) { - auto self = reinterpret_cast(op); - return self->output_types_[indice]; - }; - - OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; - }; - - OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; - }; - - OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; - OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; }; - OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; }; - OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; }; - } - - const std::string op_name_; - const std::string execution_provider_; - - std::vector input_types_; - std::vector output_types_; -}; - -//////////////////////////// OrtCustomFunc //////////////////////////////// -// The struct is to implement function-as-op. -// E.g. a function might be defined as: -// void Filter(const Ort::Custom::Tensor& floats_in, Ort::Custom::Tensor& floats_out) { ... } -// It could be registered this way: -// Ort::CustomOpDomain v2_domain{"v2"}; -// std::unique_ptr fil_op_ptr{Ort::Custom::CreateCustomOp("Filter", "CPUExecutionProvider", Filter)}; -// v2_domain.Add(fil_op_ptr.get()); -// session_options.Add(v2_domain); -// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". -template -struct OrtCustomFunc : public OrtCustomOpBase { - using ComputeFn = void (*)(Args...); - using MyType = OrtCustomFunc; - - struct Kernel { - size_t num_input_{}; - size_t num_output_{}; - ComputeFn compute_fn_{}; - std::string ep_{}; - }; - - OrtCustomFunc(const char* op_name, - const char* execution_provider, - ComputeFn compute_fn) : OrtCustomOpBase(op_name, execution_provider), - compute_fn_(compute_fn) { - ParseArgs(input_types_, output_types_); - - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { - auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); - std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t); - }; - - OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { - auto kernel = std::make_unique(); - kernel->compute_fn_ = static_cast(this_)->compute_fn_; - Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); - Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); - auto self = static_cast(this_); - kernel->ep_ = self->execution_provider_; - return reinterpret_cast(kernel.release()); - }; - - OrtCustomOp::KernelDestroy = [](void* op_kernel) { - delete reinterpret_cast(op_kernel); - }; - } - - ComputeFn compute_fn_; -}; // struct OrtCustomFunc - -/////////////////////////// OrtCustomStruct /////////////////////////// -// The struct is to implement struct-as-op. -// E.g. a struct might be defined as: -// struct Merge { -// Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...} -// void Compute(const Ort::Custom::Tensor& strings_in, -// std::string_view string_in, -// Ort::Custom::Tensor* strings_out) {...} -// bool reverse_ = false; -// }; -// It could be registered this way: -// Ort::CustomOpDomain v2_domain{"v2"}; -// std::unique_ptr mrg_op_ptr{Ort::Custom::CreateCustomOp("Merge", "CPUExecutionProvider")}; -// v2_domain.Add(mrg_op_ptr.get()); -// session_options.Add(v2_domain); -// For the complete example, please search keyword "LiteCustomOpTest" under "/onnxruntime/test/". -template -struct OrtCustomStruct : public OrtCustomOpBase { - template - using CustomComputeFn = void (CustomOp::*)(Args...); - using MyType = OrtCustomStruct; - - struct Kernel { - size_t num_input_{}; - size_t num_output_{}; - std::unique_ptr custom_op_; - std::string ep_{}; - }; - - OrtCustomStruct(const char* op_name, - const char* execution_provider) : OrtCustomOpBase(op_name, - execution_provider) { - init(&CustomOp::Compute); - } - - template - void init(CustomComputeFn) { - ParseArgs(input_types_, output_types_); - - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { - auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); - std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); - }; - - OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { - auto kernel = std::make_unique(); - Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); - Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); - kernel->custom_op_ = std::make_unique(ort_api, info); - auto self = static_cast(this_); - kernel->ep_ = self->execution_provider_; - return reinterpret_cast(kernel.release()); - }; - - OrtCustomOp::KernelDestroy = [](void* op_kernel) { - delete reinterpret_cast(op_kernel); - }; - } -}; // struct OrtCustomStruct - -/////////////////////////// CreateCustomOp //////////////////////////// - -template -OrtCustomOp* CreateCustomOp(const char* op_name, - const char* execution_provider, - void (*custom_compute_fn)(Args...)) { - using OrtCustomOpTPtr = OrtCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn).release(); -} - -template -OrtCustomOp* CreateCustomOp(const char* op_name, - const char* execution_provider) { - using OrtCustomOpTPtr = OrtCustomStruct; - return std::make_unique(op_name, execution_provider).release(); -} - -} // namespace Custom -} // namespace Ort diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index c510308da0..0da67a9671 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -571,20 +571,20 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? "nd" : "th"), " input to be of variadic type"); - ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicInputHomogeneity(op) != 0), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? "nd" : "th"), - " input to keep same homogeneity"); - ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? "nd" : "th"), - " input to keep same arity"); } else { ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single, "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? "nd" : "th"), " input to be of single type"); } + ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " input to keep same homogeneity"); + ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " input to keep same arity"); } // check outputs const auto& output_parameters = schema.outputs(); @@ -602,20 +602,20 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? "nd" : "th"), " output to be of variadic type"); - ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? "nd" : "th"), - " output to keep same homogeneity"); - ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op), - "custom op schemas mismatch, expecting ", i + 1, - i == 0 ? "st" : (i == 1 ? "nd" : "th"), - " output to keep same arity"); } else { ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single, "custom op schemas mismatch, expecting ", i + 1, i == 0 ? "st" : (i == 1 ? "nd" : "th"), " output to be of single type"); } + ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " output to keep same homogeneity"); + ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op), + "custom op schemas mismatch, expecting ", i + 1, + i == 0 ? "st" : (i == 1 ? "nd" : "th"), + " output to keep same arity"); } return Status::OK(); } diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 218a354000..af1df0ed63 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -18,7 +18,6 @@ #include "core/graph/constants.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/session/onnxruntime_lite_custom_op.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/util/thread_utils.h" @@ -2828,197 +2827,6 @@ TEST(CApiTest, TestMultiStreamInferenceSimpleSSD) { } #endif -TEST(LiteCustomOpTest, CustomFunc) { - Ort::SessionOptions session_options; - session_options.SetIntraOpNumThreads(1); - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); - session_options.SetLogSeverityLevel(0); -#if defined(_WIN32) - session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); -#elif defined(__APPLE__) - session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); -#else - session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); -#endif - - Ort::Session session{*ort_env, TSTR("testdata/fuse_select_filter.onnx"), session_options}; - - const char* input_names[] = {"vector_1", "vector_2", "alpha", "indices"}; - const char* output_names[] = {"vector_filtered"}; - - float vector_1_value[] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}; - int64_t vector_1_dim[] = {10}; - - float vector_2_value[] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}; - int64_t vector_2_dim[] = {6}; - - int32_t alpha_value[] = {2}; - int64_t alpha_dim[] = {1}; - - int32_t indices_value[] = {0, 1, 2, 3, 4, 5}; - int64_t indices_dim[] = {6}; - - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - - Ort::Value input_tensors[] = { - Ort::Value::CreateTensor(memory_info, vector_1_value, 10, vector_1_dim, 1), - Ort::Value::CreateTensor(memory_info, vector_2_value, 6, vector_2_dim, 1), - Ort::Value::CreateTensor(memory_info, alpha_value, 1, alpha_dim, 1), - Ort::Value::CreateTensor(memory_info, indices_value, 6, indices_dim, 1)}; - - Ort::RunOptions run_options; - auto output_tensors = session.Run(run_options, input_names, input_tensors, 4, output_names, 1); - const auto& vector_filterred = output_tensors.at(0); - auto type_shape_info = vector_filterred.GetTensorTypeAndShapeInfo(); - const float* floats_output = static_cast(vector_filterred.GetTensorRawData()); - ASSERT_TRUE(floats_output[0] == 8); - ASSERT_TRUE(floats_output[1] == 16); -} - -struct Merge { - Merge(const OrtApi* ort_api, const OrtKernelInfo* info) { - int64_t reverse; - ORT_ENFORCE(ort_api->KernelInfoGetAttribute_int64(info, "reverse", &reverse) == nullptr); - reverse_ = reverse != 0; - } - void Compute(const Ort::Custom::Tensor& strings_in, - std::string_view string_in, - Ort::Custom::Tensor* strings_out) { - std::vector string_pool; - for (const auto& s : strings_in.Data()) { - string_pool.emplace_back(s.data(), s.size()); - } - string_pool.emplace_back(string_in.data(), string_in.size()); - if (reverse_) { - for (auto& str : string_pool) { - std::reverse(str.begin(), str.end()); - } - std::reverse(string_pool.begin(), string_pool.end()); - } - strings_out->SetStringOutput(string_pool, {static_cast(string_pool.size())}); - } - bool reverse_ = false; -}; - -TEST(LiteCustomOpTest, CustomStruct) { - const auto& ortApi = Ort::GetApi(); - - Ort::CustomOpDomain v2_domain{"v2"}; - std::unique_ptr mrg_op_ptr{Ort::Custom::CreateCustomOp("Merge", "CPUExecutionProvider")}; - v2_domain.Add(mrg_op_ptr.get()); - - Ort::SessionOptions session_options; - session_options.SetIntraOpNumThreads(1); - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); - session_options.Add(v2_domain); - session_options.SetLogSeverityLevel(0); - - Ort::Session session{*ort_env, TSTR("testdata/merge.onnx"), session_options}; - - const char* input_names[] = {"str_in_1", "str_in_2"}; - const char* output_names[] = {"str_out"}; - - OrtAllocator* allocator = nullptr; - ASSERT_TRUE(!ortApi.GetAllocatorWithDefaultOptions(&allocator)); - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - - int64_t str_1_dims[] = {2}; - int64_t str_2_dims[] = {1}; - - Ort::Value input_tensors[] = {Ort::Value::CreateTensor(allocator, str_1_dims, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING), - Ort::Value::CreateTensor(allocator, str_2_dims, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)}; - - const char* str_1_raw[] = {"abc", "de"}; - const char* str_2_raw[] = {"fg"}; - - input_tensors[0].FillStringTensor(str_1_raw, 2); - input_tensors[1].FillStringTensor(str_2_raw, 1); - - Ort::RunOptions run_options; - auto output_tensors = session.Run(run_options, input_names, input_tensors, 2, output_names, 1); - const auto& str_out_tensor = output_tensors.at(0); - auto num_chars = str_out_tensor.GetStringTensorDataLength(); - std::vector chars(num_chars + 1, '\0'); - std::vector offsets(3); - str_out_tensor.GetStringTensorContent(static_cast(chars.data()), num_chars, offsets.data(), offsets.size()); - ASSERT_TRUE(strncmp(chars.data(), "gfedcba", 7) == 0); -} - -TEST(LiteCustomOpTest, MissingOptional) { - Ort::SessionOptions session_options; - session_options.SetIntraOpNumThreads(1); - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); - session_options.SetLogSeverityLevel(0); -#if defined(_WIN32) - session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); -#elif defined(__APPLE__) - session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); -#else - session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); -#endif - - Ort::Session session(*ort_env, TSTR("testdata/optional_2.onnx"), session_options); - - const char* input_names[] = {"float_in_1", "float_in_2"}; - const char* output_names[] = {"float_out_1"}; - - float vector_1_value[] = {0.f, 1.f, 2.f}; - int64_t vector_1_dim[] = {3}; - - float vector_2_value[] = {4.f, 5.f, 6.f, 7.f}; - int64_t vector_2_dim[] = {4}; - - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - - Ort::Value input_tensors[] = { - Ort::Value::CreateTensor(memory_info, vector_1_value, vector_1_dim[0], vector_1_dim, 1), - Ort::Value::CreateTensor(memory_info, vector_2_value, vector_2_dim[0], vector_2_dim, 1)}; - - Ort::RunOptions run_options; - auto output_tensors = session.Run(run_options, input_names, input_tensors, 2, output_names, 1); - ASSERT_TRUE(output_tensors.size() == 1); -} - -TEST(LiteCustomOpTest, HasOptional) { - Ort::SessionOptions session_options; - session_options.SetIntraOpNumThreads(1); - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); - session_options.SetLogSeverityLevel(0); -#if defined(_WIN32) - session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); -#elif defined(__APPLE__) - session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); -#else - session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); -#endif - - Ort::Session session(*ort_env, TSTR("testdata/optional_3.onnx"), session_options); - - const char* input_names[] = {"float_in_1", "float_in_2", "float_in_3"}; - const char* output_names[] = {"float_out_1", "float_out_2"}; - - float vector_1_value[] = {0.f, 1.f, 2.f}; - int64_t vector_1_dim[] = {3}; - - float vector_2_value[] = {4.f, 5.f, 6.f, 7.f}; - int64_t vector_2_dim[] = {4}; - - float vector_3_value[] = {8.f, 9.f}; - int64_t vector_3_dim[] = {2}; - - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - - Ort::Value input_tensors[] = { - Ort::Value::CreateTensor(memory_info, vector_1_value, vector_1_dim[0], vector_1_dim, 1), - Ort::Value::CreateTensor(memory_info, vector_2_value, vector_2_dim[0], vector_2_dim, 1), - Ort::Value::CreateTensor(memory_info, vector_3_value, vector_3_dim[0], vector_3_dim, 1), - }; - - Ort::RunOptions run_options; - auto output_tensors = session.Run(run_options, input_names, input_tensors, 3, output_names, 2); - ASSERT_TRUE(output_tensors.size() == 2); -} - #if !defined(ORT_MINIMAL_BUILD) TEST(MultiKernelSingleSchemaTest, valid) { Ort::SessionOptions session_options; diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 843e5a9d86..fbc520cefa 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -3,7 +3,6 @@ #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" #undef ORT_API_MANUAL_INIT -#include "onnxruntime_lite_custom_op.h" #include #include @@ -48,7 +47,28 @@ struct KernelOne { } }; -// legacy custom op registration +struct KernelTwo { + void Compute(OrtKernelContext* context) { + // Setup inputs + Ort::KernelContext ctx(context); + auto input_X = ctx.GetInput(0); + const float* X = input_X.GetTensorData(); + + // Setup output + auto dimensions = input_X.GetTensorTypeAndShapeInfo().GetShape(); + + auto output = ctx.GetOutput(0, dimensions); + int32_t* out = output.GetTensorMutableData(); + + const size_t size = output.GetTensorTypeAndShapeInfo().GetElementCount(); + + // Do computation + for (size_t i = 0; i < size; i++) { + out[i] = static_cast(round(X[i])); + } + } +}; + struct CustomOpOne : Ort::CustomOpBase { void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const { return std::make_unique().release(); @@ -67,97 +87,78 @@ struct CustomOpOne : Ort::CustomOpBase { ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; }; -// lite custom op as a function -void KernelTwo(const Ort::Custom::Tensor& X, - Ort::Custom::Tensor& Y) { - const auto& shape = X.Shape(); - auto X_raw = X.Data(); - auto Y_raw = Y.Allocate(shape); - auto total = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); - for (int64_t i = 0; i < total; i++) { - Y_raw[i] = static_cast(round(X_raw[i])); - } -} +struct CustomOpTwo : Ort::CustomOpBase { + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const { + return std::make_unique().release(); + }; + + const char* GetName() const { return "CustomOpTwo"; }; + + size_t GetInputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }; +}; + +//////////////////////////////////////////////// template -void MulTop(const Ort::Custom::Span& in, Ort::Custom::Tensor& out) { - out.Allocate({1})[0] = in[0] * in[1]; +T MulTopCompute(const T& input_0, const T& input_1) { + return input_0 * input_1; } -void Fuse( - OrtKernelContext*, - const Ort::Custom::Span& vector_1, - const Ort::Custom::Span& vector_2, - int32_t alpha, - Ort::Custom::Tensor& vector_output) { - auto len_output = std::min(vector_1.size(), vector_2.size()); - float* floats_out = static_cast(vector_output.Allocate({(int64_t)len_output})); - for (size_t i = 0; i < len_output; ++i) { - floats_out[i] = (vector_1[i] + vector_2[i]) * alpha; +struct MulTopKernelFloat { + MulTopKernelFloat(const OrtKernelInfo*){}; + ~MulTopKernelFloat() = default; + void Compute(OrtKernelContext* context) { + Ort::KernelContext ctx(context); + auto tensor_in = ctx.GetInput(0); + const float* float_in = tensor_in.GetTensorData(); + int64_t output_shape = 1; + auto tensor_out = ctx.GetOutput(0, &output_shape, 1); + auto float_out = tensor_out.GetTensorMutableData(); + *float_out = MulTopCompute(float_in[0], float_in[1]); } -} +}; -void Select(const Ort::Custom::Span& indices_in, - Ort::Custom::Tensor& indices_out) { - std::vector selected_indices; - for (size_t i = 0; i < indices_in.size(); ++i) { - if (indices_in[i] % 2 == 0) { - selected_indices.push_back(indices_in[i]); - } +struct MulTopOpFloat : Ort::CustomOpBase { + void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelFloat(info); } + const char* GetName() const { return "MulTop"; } + const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } + size_t GetInputTypeCount() const { return 1; } + ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } + size_t GetOutputTypeCount() const { return 1; } + ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } +}; + +//////////////////////////////////////////////// + +struct MulTopKernelInt32 { + MulTopKernelInt32(const OrtKernelInfo*){}; + ~MulTopKernelInt32() = default; + void Compute(OrtKernelContext* context) { + Ort::KernelContext ctx(context); + auto tensor_in = ctx.GetInput(0); + const int32_t* int_in = tensor_in.GetTensorData(); + int64_t output_shape = 1; + auto tensor_out = ctx.GetOutput(0, &output_shape, 1); + auto int_out = tensor_out.GetTensorMutableData(); + *int_out = MulTopCompute(int_in[0], int_in[1]); } +}; - int32_t* int_out = static_cast(indices_out.Allocate({static_cast(selected_indices.size())})); - for (size_t j = 0; j < selected_indices.size(); ++j) { - int_out[j] = selected_indices[j]; - } -} +struct MulTopOpInt32 : Ort::CustomOpBase { + void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelInt32(info); } + const char* GetName() const { return "MulTop"; } + const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } + size_t GetInputTypeCount() const { return 1; } + ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } + size_t GetOutputTypeCount() const { return 1; } + ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; } +}; -void Filter(const Ort::Custom::Tensor& floats_in, - Ort::Custom::Tensor& floats_out) { - const float* in = floats_in.Data(); - auto in_len = floats_in.NumberOfElement(); - - std::vector filter_floats; - for (int64_t i = 0; i < in_len; ++i) { - if (in[i] > 1.f) { - filter_floats.push_back(in[i]); - } - } - - float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); - for (size_t j = 0; j < filter_floats.size(); ++j) { - out[j] = filter_floats[j]; - } -} - -void Box(const Ort::Custom::Tensor* float_in_1, - const Ort::Custom::Tensor* float_in_2, - std::optional*> float_in_3, - Ort::Custom::Tensor* float_out_1, - std::optional*> float_out_2) { - auto raw_in_1 = float_in_1->Data(); - auto raw_in_2 = float_in_2->Data(); - - auto l_in_1 = float_in_1->Shape()[0]; - auto l_in_2 = float_in_2->Shape()[0]; - auto l_out_1 = l_in_1 + l_in_2; - - auto raw_out_1 = float_out_1->Allocate({l_out_1}); - - for (int64_t i = 0; i < l_out_1; ++i) { - raw_out_1[i] = i < l_in_1 ? raw_in_1[i] : raw_in_2[i - l_in_1]; - } - - if (float_in_3.has_value() && float_out_2.has_value()) { - auto raw_in_3 = float_in_3.value()->Data(); - auto l_in_3 = float_in_3.value()->Shape()[0]; - auto l_out_2 = l_in_2 + l_in_3; - auto raw_out_2 = float_out_2.value()->Allocate({l_out_2}); - for (int64_t i = 0; i < l_out_2; ++i) { - raw_out_2[i] = i < l_in_2 ? raw_in_2[i] : raw_in_3[i - l_in_2]; - } - } -} +//////////////////////////////////////////////// static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { static std::vector ort_custom_op_domain_container; @@ -170,30 +171,21 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA Ort::Global::api_ = api->GetApi(ORT_API_VERSION); static const CustomOpOne c_CustomOpOne; - static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; + static const CustomOpTwo c_CustomOpTwo; - static const std::unique_ptr c_MulTopOpFloat{Ort::Custom::CreateCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; - static const std::unique_ptr c_MulTopOpInt32{Ort::Custom::CreateCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; - - static const std::unique_ptr fus_op_ptr{Ort::Custom::CreateCustomOp("Fuse", "CPUExecutionProvider", Fuse)}; - static const std::unique_ptr sel_op_ptr{Ort::Custom::CreateCustomOp("Select", "CPUExecutionProvider", Select)}; - static const std::unique_ptr fil_op_ptr{Ort::Custom::CreateCustomOp("Filter", "CPUExecutionProvider", Filter)}; - static const std::unique_ptr box_op_ptr{Ort::Custom::CreateCustomOp("Box", "CPUExecutionProvider", Box)}; + static const MulTopOpFloat c_MulTopOpFloat; + static const MulTopOpInt32 c_MulTopOpInt32; OrtStatus* result = nullptr; ORT_TRY { Ort::CustomOpDomain domain{c_OpDomain}; domain.Add(&c_CustomOpOne); - domain.Add(c_CustomOpTwo.get()); + domain.Add(&c_CustomOpTwo); Ort::CustomOpDomain domain_v2{"v2"}; - domain_v2.Add(c_MulTopOpFloat.get()); - domain_v2.Add(c_MulTopOpInt32.get()); - domain_v2.Add(fus_op_ptr.get()); - domain_v2.Add(sel_op_ptr.get()); - domain_v2.Add(fil_op_ptr.get()); - domain_v2.Add(box_op_ptr.get()); + domain_v2.Add(&c_MulTopOpFloat); + domain_v2.Add(&c_MulTopOpInt32); Ort::UnownedSessionOptions session_options(options); session_options.Add(domain); diff --git a/onnxruntime/test/testdata/fuse_select_filter.onnx b/onnxruntime/test/testdata/fuse_select_filter.onnx deleted file mode 100644 index 15d7dd6478..0000000000 --- a/onnxruntime/test/testdata/fuse_select_filter.onnx +++ /dev/null @@ -1,28 +0,0 @@ -:Ä -P -vector_1 -vector_2 -alpha vector_fused fuse_node"Fuse* - fuse_algo :v2 -4 -indicesindices_selected select_node"Select:v2 -N - vector_fused -indices_selectedvector_gathered gather_node"GatherElements -; -vector_gatheredvector_filtered filter_node"Filter:v2graphZ -vector_1 - - ÿÿÿÿÿÿÿÿÿZ -vector_2 - - ÿÿÿÿÿÿÿÿÿZ -alpha - - ÿÿÿÿÿÿÿÿÿZ -indices - - ÿÿÿÿÿÿÿÿÿb& -vector_filtered - - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/merge.onnx b/onnxruntime/test/testdata/merge.onnx deleted file mode 100644 index 1ae7a86d1f..0000000000 --- a/onnxruntime/test/testdata/merge.onnx +++ /dev/null @@ -1,15 +0,0 @@ -:¯ -D -str_in_1 -str_in_2str_out -merge_node"Merge* -reverse :v2graphZ -str_in_1 - - ÿÿÿÿÿÿÿÿÿZ -str_in_2 - - ÿÿÿÿÿÿÿÿÿb -str_out - - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/optional_2.onnx b/onnxruntime/test/testdata/optional_2.onnx deleted file mode 100644 index 5c39419eda..0000000000 --- a/onnxruntime/test/testdata/optional_2.onnx +++ /dev/null @@ -1,17 +0,0 @@ -:« -8 - -float_in_1 - -float_in_2 float_out_1box_node"Box:v2graphZ! - -float_in_1 - - ÿÿÿÿÿÿÿÿÿZ! - -float_in_2 - - ÿÿÿÿÿÿÿÿÿb" - float_out_1 - - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/optional_3.onnx b/onnxruntime/test/testdata/optional_3.onnx deleted file mode 100644 index 5aa13948bb..0000000000 --- a/onnxruntime/test/testdata/optional_3.onnx +++ /dev/null @@ -1,26 +0,0 @@ -:‹ -Q - -float_in_1 - -float_in_2 - -float_in_3 float_out_1 float_out_2box_node"Box:v2graphZ! - -float_in_1 - - ÿÿÿÿÿÿÿÿÿZ! - -float_in_2 - - ÿÿÿÿÿÿÿÿÿZ! - -float_in_3 - - ÿÿÿÿÿÿÿÿÿb" - float_out_1 - - ÿÿÿÿÿÿÿÿÿb" - float_out_2 - - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file