mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Implement lite custom op API (#15590)
Implement a set of new APIs for lightweight custom ops registration, to save efforts on schema-composing. A few highlights: 1. Support build-time type inference; 2. Support function-as-op for "stateless" ops; 3. Support structure-as-op for "stateful" ops; 4. Support varied input/output forms such as span, scalar, and tensors, either optional or non-optional. --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
This commit is contained in:
parent
0e9472d391
commit
cdf4fc49fc
8 changed files with 1074 additions and 109 deletions
679
include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Normal file
679
include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Normal file
|
|
@ -0,0 +1,679 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// Summary
|
||||
// The header has APIs to save custom op authors the trouble of defining schemas,
|
||||
// which will be inferred by functions' signature, as long as their argument list has types supported here.
|
||||
// Input could be:
|
||||
// 1. Tensor of onnx data types.
|
||||
// 2. Span of onnx data types.
|
||||
// 3. Scalar of onnx data types.
|
||||
// A input could be optional if indicated as std::optional<...>.
|
||||
// For an output, it must be a tensor of onnx data types.
|
||||
// Further, the header also has utility for a simple custom struct, where resources could be kept, to be registered as a custom op.
|
||||
// For concrete examples, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
|
||||
// Note - all APIs in this header are ABI.
|
||||
|
||||
#pragma once
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#include <optional>
|
||||
#include <numeric>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace Ort {
|
||||
namespace Custom {
|
||||
|
||||
class TensorBase {
|
||||
public:
|
||||
TensorBase(OrtKernelContext* ctx) : ctx_(ctx) {}
|
||||
operator bool() const {
|
||||
return shape_.has_value();
|
||||
}
|
||||
|
||||
protected:
|
||||
struct KernelContext ctx_;
|
||||
std::optional<std::vector<int64_t>> shape_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Span {
|
||||
const T* data_ = {};
|
||||
size_t size_ = {};
|
||||
void Assign(const T* data, size_t size) {
|
||||
data_ = data;
|
||||
size_ = size;
|
||||
}
|
||||
size_t size() const { return size_; }
|
||||
T operator[](size_t indice) const {
|
||||
return data_[indice];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Tensor : public TensorBase {
|
||||
public:
|
||||
using TT = typename std::remove_reference<T>::type;
|
||||
Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) {
|
||||
if (is_input_) {
|
||||
if (indice >= ctx_.GetInputCount()) {
|
||||
ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
const_value_ = ctx_.GetInput(indice);
|
||||
auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
|
||||
shape_ = type_shape_info.GetShape();
|
||||
}
|
||||
}
|
||||
const std::vector<int64_t>& Shape() const {
|
||||
if (!shape_.has_value()) {
|
||||
ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return shape_.value();
|
||||
}
|
||||
int64_t NumberOfElement() const {
|
||||
if (shape_.has_value()) {
|
||||
return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
const TT* Data() const {
|
||||
return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
|
||||
}
|
||||
TT* Allocate(const std::vector<int64_t>& shape) {
|
||||
shape_ = shape;
|
||||
if (!data_) {
|
||||
shape_ = shape;
|
||||
data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
|
||||
}
|
||||
return data_;
|
||||
}
|
||||
static TT GetT() { return (TT)0; }
|
||||
const Span<T>& AsSpan() {
|
||||
if (!shape_.has_value() || shape_->size() != 1) {
|
||||
ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
|
||||
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
|
||||
return span_;
|
||||
}
|
||||
const T& AsScalar() {
|
||||
if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
|
||||
ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
|
||||
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return *Data();
|
||||
}
|
||||
|
||||
private:
|
||||
size_t indice_;
|
||||
bool is_input_;
|
||||
ConstValue const_value_; // for input
|
||||
TT* data_{}; // for output
|
||||
Span<T> span_;
|
||||
};
|
||||
|
||||
template <>
|
||||
class Tensor<std::string> : public TensorBase {
|
||||
public:
|
||||
using strings = std::vector<std::string>;
|
||||
|
||||
Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) {
|
||||
if (is_input_) {
|
||||
if (indice >= ctx_.GetInputCount()) {
|
||||
ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
auto const_value = ctx_.GetInput(indice);
|
||||
auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
|
||||
shape_ = type_shape_info.GetShape();
|
||||
auto num_chars = const_value.GetStringTensorDataLength();
|
||||
// note - there will be copy ...
|
||||
auto num_strings = static_cast<size_t>(NumberOfElement());
|
||||
if (num_strings) {
|
||||
std::vector<char> chars(num_chars + 1, '\0');
|
||||
std::vector<size_t> offsets(num_strings);
|
||||
const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
|
||||
auto upper_bound = num_strings - 1;
|
||||
input_strings_.resize(num_strings);
|
||||
for (size_t i = upper_bound;; --i) {
|
||||
if (i < upper_bound) {
|
||||
chars[offsets[i + 1]] = '\0';
|
||||
}
|
||||
input_strings_[i] = chars.data() + offsets[i];
|
||||
if (0 == i) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t NumberOfElement() const {
|
||||
if (shape_.has_value()) {
|
||||
return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
const strings& Data() const {
|
||||
return input_strings_;
|
||||
}
|
||||
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
|
||||
shape_ = dims;
|
||||
std::vector<const char*> raw;
|
||||
for (const auto& s : ss) {
|
||||
raw.push_back(s.data());
|
||||
}
|
||||
auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
|
||||
// note - there will be copy ...
|
||||
output.FillStringTensor(raw.data(), raw.size());
|
||||
}
|
||||
const Span<std::string>& AsSpan() {
|
||||
ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
const std::string& AsScalar() {
|
||||
if (input_strings_.size() != 1) {
|
||||
ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
|
||||
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return input_strings_[0];
|
||||
}
|
||||
|
||||
private:
|
||||
size_t indice_;
|
||||
bool is_input_;
|
||||
std::vector<std::string> input_strings_; // for input
|
||||
};
|
||||
|
||||
template <>
|
||||
class Tensor<std::string_view> : public TensorBase {
|
||||
public:
|
||||
using strings = std::vector<std::string>;
|
||||
using string_views = std::vector<std::string_view>;
|
||||
|
||||
Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) {
|
||||
if (is_input_) {
|
||||
if (indice >= ctx_.GetInputCount()) {
|
||||
ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
auto const_value = ctx_.GetInput(indice);
|
||||
auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
|
||||
shape_ = type_shape_info.GetShape();
|
||||
auto num_chars = const_value.GetStringTensorDataLength();
|
||||
chars_.resize(num_chars + 1, '\0');
|
||||
auto num_strings = static_cast<size_t>(NumberOfElement());
|
||||
if (num_strings) {
|
||||
std::vector<size_t> offsets(num_strings);
|
||||
const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
|
||||
offsets.push_back(num_chars);
|
||||
for (size_t i = 0; i < num_strings; ++i) {
|
||||
input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
int64_t NumberOfElement() const {
|
||||
if (shape_.has_value()) {
|
||||
return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
const string_views& Data() const {
|
||||
return input_string_views_;
|
||||
}
|
||||
void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
|
||||
shape_ = dims;
|
||||
std::vector<const char*> raw;
|
||||
for (const auto& s : ss) {
|
||||
raw.push_back(s.data());
|
||||
}
|
||||
auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
|
||||
// note - there will be copy ...
|
||||
output.FillStringTensor(raw.data(), raw.size());
|
||||
}
|
||||
const Span<std::string_view>& AsSpan() {
|
||||
ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
std::string_view AsScalar() {
|
||||
if (input_string_views_.size() != 1) {
|
||||
ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
|
||||
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
return input_string_views_[0];
|
||||
}
|
||||
|
||||
private:
|
||||
size_t indice_;
|
||||
bool is_input_;
|
||||
std::vector<char> chars_; // for input
|
||||
std::vector<std::string_view> input_string_views_; // for input
|
||||
};
|
||||
|
||||
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
|
||||
|
||||
//////////////////////////// OrtCustomOpBase ////////////////////////////////
|
||||
|
||||
struct OrtCustomOpBase : public OrtCustomOp {
|
||||
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
|
||||
using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
|
||||
|
||||
// CreateTuple
|
||||
template <size_t ith_input, size_t ith_output, typename... Ts>
|
||||
static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
|
||||
CreateTuple(OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
|
||||
return std::make_tuple();
|
||||
}
|
||||
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
|
||||
static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
|
||||
std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
|
||||
auto next = CreateTuple<ith_input, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
|
||||
return std::tuple_cat(current, next);
|
||||
}
|
||||
|
||||
#define CREATE_TUPLE_INPUT(data_type) \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
if (ith_input < num_input) { \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} else { \
|
||||
std::tuple<T> current = std::tuple<T>{}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
if ("CPUExecutionProvider" != ep) { \
|
||||
ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
||||
} \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
if ("CPUExecutionProvider" != ep) { \
|
||||
ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
||||
} \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
if (ith_input < num_input) { \
|
||||
if ("CPUExecutionProvider" != ep) { \
|
||||
ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
||||
} \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} else { \
|
||||
std::tuple<T> current = std::tuple<T>{}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
if ("CPUExecutionProvider" != ep) { \
|
||||
ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
||||
} \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
if (ith_input < num_input) { \
|
||||
if ("CPUExecutionProvider" != ep) { \
|
||||
ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
|
||||
} \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} else { \
|
||||
std::tuple<T> current = std::tuple<T>{}; \
|
||||
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
}
|
||||
#define CREATE_TUPLE_OUTPUT(data_type) \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
|
||||
static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
|
||||
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
|
||||
if (ith_output < num_output) { \
|
||||
tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
|
||||
std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} else { \
|
||||
std::tuple<T> current = std::tuple<T>{}; \
|
||||
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, tensors, num_input, num_output, ep); \
|
||||
return std::tuple_cat(current, next); \
|
||||
} \
|
||||
}
|
||||
#define CREATE_TUPLE(data_type) \
|
||||
CREATE_TUPLE_INPUT(data_type) \
|
||||
CREATE_TUPLE_OUTPUT(data_type)
|
||||
|
||||
CREATE_TUPLE(bool)
|
||||
CREATE_TUPLE(float)
|
||||
CREATE_TUPLE(Ort::Float16_t)
|
||||
CREATE_TUPLE(Ort::BFloat16_t)
|
||||
CREATE_TUPLE(double)
|
||||
CREATE_TUPLE(int8_t)
|
||||
CREATE_TUPLE(int16_t)
|
||||
CREATE_TUPLE(int32_t)
|
||||
CREATE_TUPLE(int64_t)
|
||||
CREATE_TUPLE(uint8_t)
|
||||
CREATE_TUPLE(uint16_t)
|
||||
CREATE_TUPLE(uint32_t)
|
||||
CREATE_TUPLE(uint64_t)
|
||||
CREATE_TUPLE(std::string)
|
||||
CREATE_TUPLE_INPUT(std::string_view)
|
||||
|
||||
// ParseArgs ...
|
||||
template <typename... Ts>
|
||||
static typename std::enable_if<0 == sizeof...(Ts)>::type
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
|
||||
ParseArgs<Ts...>(input_types, output_types);
|
||||
}
|
||||
|
||||
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
|
||||
template <typename T, typename... Ts> \
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
||||
input_types.push_back(onnx_type); \
|
||||
ParseArgs<Ts...>(input_types, output_types); \
|
||||
} \
|
||||
template <typename T, typename... Ts> \
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
||||
input_types.push_back(onnx_type); \
|
||||
ParseArgs<Ts...>(input_types, output_types); \
|
||||
}
|
||||
|
||||
#define PARSE_INPUT(data_type, onnx_type) \
|
||||
PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
|
||||
PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
|
||||
PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
|
||||
PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
|
||||
PARSE_INPUT_BASE(data_type, onnx_type)
|
||||
|
||||
#define PARSE_OUTPUT(data_type, onnx_type) \
|
||||
template <typename T, typename... Ts> \
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
||||
output_types.push_back(onnx_type); \
|
||||
ParseArgs<Ts...>(input_types, output_types); \
|
||||
} \
|
||||
template <typename T, typename... Ts> \
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
||||
output_types.push_back(onnx_type); \
|
||||
ParseArgs<Ts...>(input_types, output_types); \
|
||||
} \
|
||||
template <typename T, typename... Ts> \
|
||||
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
|
||||
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
|
||||
output_types.push_back(onnx_type); \
|
||||
ParseArgs<Ts...>(input_types, output_types); \
|
||||
}
|
||||
|
||||
#define PARSE_ARGS(data_type, onnx_type) \
|
||||
PARSE_INPUT(data_type, onnx_type) \
|
||||
PARSE_OUTPUT(data_type, onnx_type)
|
||||
|
||||
PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
|
||||
PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
|
||||
PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
|
||||
PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
|
||||
PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
|
||||
PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
|
||||
PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
|
||||
PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
|
||||
PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
|
||||
PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
|
||||
PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
|
||||
PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
|
||||
PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
|
||||
PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
|
||||
PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
|
||||
|
||||
OrtCustomOpBase(const char* op_name,
|
||||
const char* execution_provider) : op_name_(op_name),
|
||||
execution_provider_(execution_provider) {
|
||||
OrtCustomOp::version = ORT_API_VERSION;
|
||||
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtCustomOpBase*>(op)->op_name_.c_str(); };
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtCustomOpBase*)op)->execution_provider_.c_str(); };
|
||||
OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };
|
||||
|
||||
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
|
||||
auto self = reinterpret_cast<const OrtCustomOpBase*>(op);
|
||||
return self->input_types_.size();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
|
||||
auto self = reinterpret_cast<const OrtCustomOpBase*>(op);
|
||||
return self->input_types_[indice];
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
|
||||
auto self = reinterpret_cast<const OrtCustomOpBase*>(op);
|
||||
return self->output_types_.size();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
|
||||
auto self = reinterpret_cast<const OrtCustomOpBase*>(op);
|
||||
return self->output_types_[indice];
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
|
||||
return INPUT_OUTPUT_OPTIONAL;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) {
|
||||
return INPUT_OUTPUT_OPTIONAL;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
|
||||
OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
|
||||
OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
|
||||
OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
|
||||
}
|
||||
|
||||
const std::string op_name_;
|
||||
const std::string execution_provider_;
|
||||
|
||||
std::vector<ONNXTensorElementDataType> input_types_;
|
||||
std::vector<ONNXTensorElementDataType> output_types_;
|
||||
};
|
||||
|
||||
//////////////////////////// OrtCustomFunc ////////////////////////////////
|
||||
// The struct is to implement function-as-op.
|
||||
// E.g. a function might be defined as:
|
||||
// void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
|
||||
// It could be registered this way:
|
||||
// Ort::CustomOpDomain v2_domain{"v2"};
|
||||
// std::unique_ptr<OrtCustomOp> fil_op_ptr{Ort::Custom::CreateCustomOp("Filter", "CPUExecutionProvider", Filter)};
|
||||
// v2_domain.Add(fil_op_ptr.get());
|
||||
// session_options.Add(v2_domain);
|
||||
// For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
|
||||
template <typename... Args>
|
||||
struct OrtCustomFunc : public OrtCustomOpBase {
|
||||
using ComputeFn = void (*)(Args...);
|
||||
using MyType = OrtCustomFunc<Args...>;
|
||||
|
||||
struct Kernel {
|
||||
size_t num_input_{};
|
||||
size_t num_output_{};
|
||||
ComputeFn compute_fn_{};
|
||||
std::string ep_{};
|
||||
};
|
||||
|
||||
OrtCustomFunc(const char* op_name,
|
||||
const char* execution_provider,
|
||||
ComputeFn compute_fn) : OrtCustomOpBase(op_name, execution_provider),
|
||||
compute_fn_(compute_fn) {
|
||||
ParseArgs<Args...>(input_types_, output_types_);
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
||||
std::vector<TensorPtr> tensors;
|
||||
auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_);
|
||||
std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
|
||||
};
|
||||
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
|
||||
auto kernel = std::make_unique<Kernel>();
|
||||
kernel->compute_fn_ = static_cast<const MyType*>(this_)->compute_fn_;
|
||||
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
|
||||
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
|
||||
auto self = static_cast<const OrtCustomFunc*>(this_);
|
||||
kernel->ep_ = self->execution_provider_;
|
||||
return reinterpret_cast<void*>(kernel.release());
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
|
||||
delete reinterpret_cast<Kernel*>(op_kernel);
|
||||
};
|
||||
}
|
||||
|
||||
ComputeFn compute_fn_;
|
||||
}; // struct OrtCustomFunc
|
||||
|
||||
/////////////////////////// OrtCustomStruct ///////////////////////////
|
||||
// The struct is to implement struct-as-op.
|
||||
// E.g. a struct might be defined as:
|
||||
// struct Merge {
|
||||
// Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
|
||||
// void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
|
||||
// std::string_view string_in,
|
||||
// Ort::Custom::Tensor<std::string>* strings_out) {...}
|
||||
// bool reverse_ = false;
|
||||
// };
|
||||
// It could be registered this way:
|
||||
// Ort::CustomOpDomain v2_domain{"v2"};
|
||||
// std::unique_ptr<OrtCustomOp> mrg_op_ptr{Ort::Custom::CreateCustomOp<Merge>("Merge", "CPUExecutionProvider")};
|
||||
// v2_domain.Add(mrg_op_ptr.get());
|
||||
// session_options.Add(v2_domain);
|
||||
// For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
|
||||
template <typename CustomOp>
|
||||
struct OrtCustomStruct : public OrtCustomOpBase {
|
||||
template <typename... Args>
|
||||
using CustomComputeFn = void (CustomOp::*)(Args...);
|
||||
using MyType = OrtCustomStruct<CustomOp>;
|
||||
|
||||
struct Kernel {
|
||||
size_t num_input_{};
|
||||
size_t num_output_{};
|
||||
std::unique_ptr<CustomOp> custom_op_;
|
||||
std::string ep_{};
|
||||
};
|
||||
|
||||
OrtCustomStruct(const char* op_name,
|
||||
const char* execution_provider) : OrtCustomOpBase(op_name,
|
||||
execution_provider) {
|
||||
init(&CustomOp::Compute);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void init(CustomComputeFn<Args...>) {
|
||||
ParseArgs<Args...>(input_types_, output_types_);
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
||||
std::vector<TensorPtr> tensors;
|
||||
auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_);
|
||||
std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
|
||||
};
|
||||
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
|
||||
auto kernel = std::make_unique<Kernel>();
|
||||
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
|
||||
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
|
||||
kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
|
||||
auto self = static_cast<const OrtCustomStruct*>(this_);
|
||||
kernel->ep_ = self->execution_provider_;
|
||||
return reinterpret_cast<void*>(kernel.release());
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
|
||||
delete reinterpret_cast<Kernel*>(op_kernel);
|
||||
};
|
||||
}
|
||||
}; // struct OrtCustomStruct
|
||||
|
||||
/////////////////////////// CreateCustomOp ////////////////////////////
|
||||
|
||||
template <typename... Args>
|
||||
OrtCustomOp* CreateCustomOp(const char* op_name,
|
||||
const char* execution_provider,
|
||||
void (*custom_compute_fn)(Args...)) {
|
||||
using OrtCustomOpTPtr = OrtCustomFunc<Args...>;
|
||||
return std::make_unique<OrtCustomOpTPtr>(op_name, execution_provider, custom_compute_fn).release();
|
||||
}
|
||||
|
||||
template <typename CustomOp>
|
||||
OrtCustomOp* CreateCustomOp(const char* op_name,
|
||||
const char* execution_provider) {
|
||||
using OrtCustomOpTPtr = OrtCustomStruct<CustomOp>;
|
||||
return std::make_unique<OrtCustomOpTPtr>(op_name, execution_provider).release();
|
||||
}
|
||||
|
||||
} // namespace Custom
|
||||
} // namespace Ort
|
||||
|
|
@ -571,20 +571,20 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
|
|||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" input to be of variadic type");
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicInputHomogeneity(op) != 0),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" input to keep same homogeneity");
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" input to keep same arity");
|
||||
} else {
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single,
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" input to be of single type");
|
||||
}
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" input to keep same homogeneity");
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" input to keep same arity");
|
||||
}
|
||||
// check outputs
|
||||
const auto& output_parameters = schema.outputs();
|
||||
|
|
@ -602,20 +602,20 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
|
|||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" output to be of variadic type");
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" output to keep same homogeneity");
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" output to keep same arity");
|
||||
} else {
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetOption() == onnx::OpSchema::FormalParameterOption::Single,
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" output to be of single type");
|
||||
}
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetIsHomogeneous() == (op->GetVariadicOutputHomogeneity(op) != 0),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" output to keep same homogeneity");
|
||||
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
|
||||
"custom op schemas mismatch, expecting ", i + 1,
|
||||
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
|
||||
" output to keep same arity");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#include "core/graph/constants.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/session/onnxruntime_lite_custom_op.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/session/onnxruntime_run_options_config_keys.h"
|
||||
#include "core/util/thread_utils.h"
|
||||
|
|
@ -2827,6 +2828,197 @@ TEST(CApiTest, TestMultiStreamInferenceSimpleSSD) {
|
|||
}
|
||||
#endif
|
||||
|
||||
TEST(LiteCustomOpTest, CustomFunc) {
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.SetIntraOpNumThreads(1);
|
||||
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
||||
session_options.SetLogSeverityLevel(0);
|
||||
#if defined(_WIN32)
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll"));
|
||||
#elif defined(__APPLE__)
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib"));
|
||||
#else
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so"));
|
||||
#endif
|
||||
|
||||
Ort::Session session{*ort_env, TSTR("testdata/fuse_select_filter.onnx"), session_options};
|
||||
|
||||
const char* input_names[] = {"vector_1", "vector_2", "alpha", "indices"};
|
||||
const char* output_names[] = {"vector_filtered"};
|
||||
|
||||
float vector_1_value[] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f};
|
||||
int64_t vector_1_dim[] = {10};
|
||||
|
||||
float vector_2_value[] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
|
||||
int64_t vector_2_dim[] = {6};
|
||||
|
||||
int32_t alpha_value[] = {2};
|
||||
int64_t alpha_dim[] = {1};
|
||||
|
||||
int32_t indices_value[] = {0, 1, 2, 3, 4, 5};
|
||||
int64_t indices_dim[] = {6};
|
||||
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value input_tensors[] = {
|
||||
Ort::Value::CreateTensor<float>(memory_info, vector_1_value, 10, vector_1_dim, 1),
|
||||
Ort::Value::CreateTensor<float>(memory_info, vector_2_value, 6, vector_2_dim, 1),
|
||||
Ort::Value::CreateTensor<int32_t>(memory_info, alpha_value, 1, alpha_dim, 1),
|
||||
Ort::Value::CreateTensor<int32_t>(memory_info, indices_value, 6, indices_dim, 1)};
|
||||
|
||||
Ort::RunOptions run_options;
|
||||
auto output_tensors = session.Run(run_options, input_names, input_tensors, 4, output_names, 1);
|
||||
const auto& vector_filterred = output_tensors.at(0);
|
||||
auto type_shape_info = vector_filterred.GetTensorTypeAndShapeInfo();
|
||||
const float* floats_output = static_cast<const float*>(vector_filterred.GetTensorRawData());
|
||||
ASSERT_TRUE(floats_output[0] == 8);
|
||||
ASSERT_TRUE(floats_output[1] == 16);
|
||||
}
|
||||
|
||||
struct Merge {
|
||||
Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {
|
||||
int64_t reverse;
|
||||
ORT_ENFORCE(ort_api->KernelInfoGetAttribute_int64(info, "reverse", &reverse) == nullptr);
|
||||
reverse_ = reverse != 0;
|
||||
}
|
||||
void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
|
||||
std::string_view string_in,
|
||||
Ort::Custom::Tensor<std::string>* strings_out) {
|
||||
std::vector<std::string> string_pool;
|
||||
for (const auto& s : strings_in.Data()) {
|
||||
string_pool.emplace_back(s.data(), s.size());
|
||||
}
|
||||
string_pool.emplace_back(string_in.data(), string_in.size());
|
||||
if (reverse_) {
|
||||
for (auto& str : string_pool) {
|
||||
std::reverse(str.begin(), str.end());
|
||||
}
|
||||
std::reverse(string_pool.begin(), string_pool.end());
|
||||
}
|
||||
strings_out->SetStringOutput(string_pool, {static_cast<int64_t>(string_pool.size())});
|
||||
}
|
||||
bool reverse_ = false;
|
||||
};
|
||||
|
||||
TEST(LiteCustomOpTest, CustomStruct) {
|
||||
const auto& ortApi = Ort::GetApi();
|
||||
|
||||
Ort::CustomOpDomain v2_domain{"v2"};
|
||||
std::unique_ptr<OrtCustomOp> mrg_op_ptr{Ort::Custom::CreateCustomOp<Merge>("Merge", "CPUExecutionProvider")};
|
||||
v2_domain.Add(mrg_op_ptr.get());
|
||||
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.SetIntraOpNumThreads(1);
|
||||
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
||||
session_options.Add(v2_domain);
|
||||
session_options.SetLogSeverityLevel(0);
|
||||
|
||||
Ort::Session session{*ort_env, TSTR("testdata/merge.onnx"), session_options};
|
||||
|
||||
const char* input_names[] = {"str_in_1", "str_in_2"};
|
||||
const char* output_names[] = {"str_out"};
|
||||
|
||||
OrtAllocator* allocator = nullptr;
|
||||
ASSERT_TRUE(!ortApi.GetAllocatorWithDefaultOptions(&allocator));
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
|
||||
int64_t str_1_dims[] = {2};
|
||||
int64_t str_2_dims[] = {1};
|
||||
|
||||
Ort::Value input_tensors[] = {Ort::Value::CreateTensor(allocator, str_1_dims, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING),
|
||||
Ort::Value::CreateTensor(allocator, str_2_dims, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)};
|
||||
|
||||
const char* str_1_raw[] = {"abc", "de"};
|
||||
const char* str_2_raw[] = {"fg"};
|
||||
|
||||
input_tensors[0].FillStringTensor(str_1_raw, 2);
|
||||
input_tensors[1].FillStringTensor(str_2_raw, 1);
|
||||
|
||||
Ort::RunOptions run_options;
|
||||
auto output_tensors = session.Run(run_options, input_names, input_tensors, 2, output_names, 1);
|
||||
const auto& str_out_tensor = output_tensors.at(0);
|
||||
auto num_chars = str_out_tensor.GetStringTensorDataLength();
|
||||
std::vector<char> chars(num_chars + 1, '\0');
|
||||
std::vector<size_t> offsets(3);
|
||||
str_out_tensor.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
|
||||
ASSERT_TRUE(strncmp(chars.data(), "gfedcba", 7) == 0);
|
||||
}
|
||||
|
||||
TEST(LiteCustomOpTest, MissingOptional) {
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.SetIntraOpNumThreads(1);
|
||||
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
||||
session_options.SetLogSeverityLevel(0);
|
||||
#if defined(_WIN32)
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll"));
|
||||
#elif defined(__APPLE__)
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib"));
|
||||
#else
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so"));
|
||||
#endif
|
||||
|
||||
Ort::Session session(*ort_env, TSTR("testdata/optional_2.onnx"), session_options);
|
||||
|
||||
const char* input_names[] = {"float_in_1", "float_in_2"};
|
||||
const char* output_names[] = {"float_out_1"};
|
||||
|
||||
float vector_1_value[] = {0.f, 1.f, 2.f};
|
||||
int64_t vector_1_dim[] = {3};
|
||||
|
||||
float vector_2_value[] = {4.f, 5.f, 6.f, 7.f};
|
||||
int64_t vector_2_dim[] = {4};
|
||||
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value input_tensors[] = {
|
||||
Ort::Value::CreateTensor<float>(memory_info, vector_1_value, vector_1_dim[0], vector_1_dim, 1),
|
||||
Ort::Value::CreateTensor<float>(memory_info, vector_2_value, vector_2_dim[0], vector_2_dim, 1)};
|
||||
|
||||
Ort::RunOptions run_options;
|
||||
auto output_tensors = session.Run(run_options, input_names, input_tensors, 2, output_names, 1);
|
||||
ASSERT_TRUE(output_tensors.size() == 1);
|
||||
}
|
||||
|
||||
TEST(LiteCustomOpTest, HasOptional) {
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.SetIntraOpNumThreads(1);
|
||||
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
|
||||
session_options.SetLogSeverityLevel(0);
|
||||
#if defined(_WIN32)
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll"));
|
||||
#elif defined(__APPLE__)
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib"));
|
||||
#else
|
||||
session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so"));
|
||||
#endif
|
||||
|
||||
Ort::Session session(*ort_env, TSTR("testdata/optional_3.onnx"), session_options);
|
||||
|
||||
const char* input_names[] = {"float_in_1", "float_in_2", "float_in_3"};
|
||||
const char* output_names[] = {"float_out_1", "float_out_2"};
|
||||
|
||||
float vector_1_value[] = {0.f, 1.f, 2.f};
|
||||
int64_t vector_1_dim[] = {3};
|
||||
|
||||
float vector_2_value[] = {4.f, 5.f, 6.f, 7.f};
|
||||
int64_t vector_2_dim[] = {4};
|
||||
|
||||
float vector_3_value[] = {8.f, 9.f};
|
||||
int64_t vector_3_dim[] = {2};
|
||||
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value input_tensors[] = {
|
||||
Ort::Value::CreateTensor<float>(memory_info, vector_1_value, vector_1_dim[0], vector_1_dim, 1),
|
||||
Ort::Value::CreateTensor<float>(memory_info, vector_2_value, vector_2_dim[0], vector_2_dim, 1),
|
||||
Ort::Value::CreateTensor<float>(memory_info, vector_3_value, vector_3_dim[0], vector_3_dim, 1),
|
||||
};
|
||||
|
||||
Ort::RunOptions run_options;
|
||||
auto output_tensors = session.Run(run_options, input_names, input_tensors, 3, output_names, 2);
|
||||
ASSERT_TRUE(output_tensors.size() == 2);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
TEST(MultiKernelSingleSchemaTest, valid) {
|
||||
Ort::SessionOptions session_options;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#define ORT_API_MANUAL_INIT
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#undef ORT_API_MANUAL_INIT
|
||||
#include "onnxruntime_lite_custom_op.h"
|
||||
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
|
|
@ -47,28 +48,7 @@ struct KernelOne {
|
|||
}
|
||||
};
|
||||
|
||||
struct KernelTwo {
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
Ort::KernelContext ctx(context);
|
||||
auto input_X = ctx.GetInput(0);
|
||||
const float* X = input_X.GetTensorData<float>();
|
||||
|
||||
// Setup output
|
||||
auto dimensions = input_X.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
auto output = ctx.GetOutput(0, dimensions);
|
||||
int32_t* out = output.GetTensorMutableData<int32_t>();
|
||||
|
||||
const size_t size = output.GetTensorTypeAndShapeInfo().GetElementCount();
|
||||
|
||||
// Do computation
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
out[i] = static_cast<int32_t>(round(X[i]));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// legacy custom op registration
|
||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||
void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const {
|
||||
return std::make_unique<KernelOne>().release();
|
||||
|
|
@ -87,78 +67,97 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
|||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
};
|
||||
|
||||
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||
void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* /* info */) const {
|
||||
return std::make_unique<CustomOpTwo>().release();
|
||||
};
|
||||
|
||||
const char* GetName() const { return "CustomOpTwo"; };
|
||||
|
||||
size_t GetInputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
T MulTopCompute(const T& input_0, const T& input_1) {
|
||||
return input_0 * input_1;
|
||||
// lite custom op as a function
|
||||
void KernelTwo(const Ort::Custom::Tensor<float>& X,
|
||||
Ort::Custom::Tensor<int32_t>& Y) {
|
||||
const auto& shape = X.Shape();
|
||||
auto X_raw = X.Data();
|
||||
auto Y_raw = Y.Allocate(shape);
|
||||
auto total = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
|
||||
for (int64_t i = 0; i < total; i++) {
|
||||
Y_raw[i] = static_cast<int32_t>(round(X_raw[i]));
|
||||
}
|
||||
}
|
||||
|
||||
struct MulTopKernelFloat {
|
||||
MulTopKernelFloat(const OrtKernelInfo*){};
|
||||
~MulTopKernelFloat() = default;
|
||||
void Compute(OrtKernelContext* context) {
|
||||
Ort::KernelContext ctx(context);
|
||||
auto tensor_in = ctx.GetInput(0);
|
||||
const float* float_in = tensor_in.GetTensorData<float>();
|
||||
int64_t output_shape = 1;
|
||||
auto tensor_out = ctx.GetOutput(0, &output_shape, 1);
|
||||
auto float_out = tensor_out.GetTensorMutableData<float>();
|
||||
*float_out = MulTopCompute(float_in[0], float_in[1]);
|
||||
template <typename T>
|
||||
void MulTop(const Ort::Custom::Span<T>& in, Ort::Custom::Tensor<T>& out) {
|
||||
out.Allocate({1})[0] = in[0] * in[1];
|
||||
}
|
||||
|
||||
void Fuse(
|
||||
OrtKernelContext*,
|
||||
const Ort::Custom::Span<float>& vector_1,
|
||||
const Ort::Custom::Span<float>& vector_2,
|
||||
int32_t alpha,
|
||||
Ort::Custom::Tensor<float>& vector_output) {
|
||||
auto len_output = std::min(vector_1.size(), vector_2.size());
|
||||
float* floats_out = static_cast<float*>(vector_output.Allocate({(int64_t)len_output}));
|
||||
for (size_t i = 0; i < len_output; ++i) {
|
||||
floats_out[i] = (vector_1[i] + vector_2[i]) * alpha;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
struct MulTopOpFloat : Ort::CustomOpBase<MulTopOpFloat, MulTopKernelFloat> {
|
||||
void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelFloat(info); }
|
||||
const char* GetName() const { return "MulTop"; }
|
||||
const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
|
||||
size_t GetInputTypeCount() const { return 1; }
|
||||
ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
|
||||
size_t GetOutputTypeCount() const { return 1; }
|
||||
ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////
|
||||
|
||||
struct MulTopKernelInt32 {
|
||||
MulTopKernelInt32(const OrtKernelInfo*){};
|
||||
~MulTopKernelInt32() = default;
|
||||
void Compute(OrtKernelContext* context) {
|
||||
Ort::KernelContext ctx(context);
|
||||
auto tensor_in = ctx.GetInput(0);
|
||||
const int32_t* int_in = tensor_in.GetTensorData<int32_t>();
|
||||
int64_t output_shape = 1;
|
||||
auto tensor_out = ctx.GetOutput(0, &output_shape, 1);
|
||||
auto int_out = tensor_out.GetTensorMutableData<int32_t>();
|
||||
*int_out = MulTopCompute(int_in[0], int_in[1]);
|
||||
void Select(const Ort::Custom::Span<int32_t>& indices_in,
|
||||
Ort::Custom::Tensor<int32_t>& indices_out) {
|
||||
std::vector<int32_t> selected_indices;
|
||||
for (size_t i = 0; i < indices_in.size(); ++i) {
|
||||
if (indices_in[i] % 2 == 0) {
|
||||
selected_indices.push_back(indices_in[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct MulTopOpInt32 : Ort::CustomOpBase<MulTopOpInt32, MulTopKernelInt32> {
|
||||
void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelInt32(info); }
|
||||
const char* GetName() const { return "MulTop"; }
|
||||
const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
|
||||
size_t GetInputTypeCount() const { return 1; }
|
||||
ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }
|
||||
size_t GetOutputTypeCount() const { return 1; }
|
||||
ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; }
|
||||
};
|
||||
int32_t* int_out = static_cast<int32_t*>(indices_out.Allocate({static_cast<int64_t>(selected_indices.size())}));
|
||||
for (size_t j = 0; j < selected_indices.size(); ++j) {
|
||||
int_out[j] = selected_indices[j];
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////
|
||||
void Filter(const Ort::Custom::Tensor<float>& floats_in,
|
||||
Ort::Custom::Tensor<float>& floats_out) {
|
||||
const float* in = floats_in.Data();
|
||||
auto in_len = floats_in.NumberOfElement();
|
||||
|
||||
std::vector<float> filter_floats;
|
||||
for (int64_t i = 0; i < in_len; ++i) {
|
||||
if (in[i] > 1.f) {
|
||||
filter_floats.push_back(in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
float* out = static_cast<float*>(floats_out.Allocate({static_cast<int64_t>(filter_floats.size())}));
|
||||
for (size_t j = 0; j < filter_floats.size(); ++j) {
|
||||
out[j] = filter_floats[j];
|
||||
}
|
||||
}
|
||||
|
||||
void Box(const Ort::Custom::Tensor<float>* float_in_1,
|
||||
const Ort::Custom::Tensor<float>* float_in_2,
|
||||
std::optional<const Ort::Custom::Tensor<float>*> float_in_3,
|
||||
Ort::Custom::Tensor<float>* float_out_1,
|
||||
std::optional<Ort::Custom::Tensor<float>*> float_out_2) {
|
||||
auto raw_in_1 = float_in_1->Data();
|
||||
auto raw_in_2 = float_in_2->Data();
|
||||
|
||||
auto l_in_1 = float_in_1->Shape()[0];
|
||||
auto l_in_2 = float_in_2->Shape()[0];
|
||||
auto l_out_1 = l_in_1 + l_in_2;
|
||||
|
||||
auto raw_out_1 = float_out_1->Allocate({l_out_1});
|
||||
|
||||
for (int64_t i = 0; i < l_out_1; ++i) {
|
||||
raw_out_1[i] = i < l_in_1 ? raw_in_1[i] : raw_in_2[i - l_in_1];
|
||||
}
|
||||
|
||||
if (float_in_3.has_value() && float_out_2.has_value()) {
|
||||
auto raw_in_3 = float_in_3.value()->Data();
|
||||
auto l_in_3 = float_in_3.value()->Shape()[0];
|
||||
auto l_out_2 = l_in_2 + l_in_3;
|
||||
auto raw_out_2 = float_out_2.value()->Allocate({l_out_2});
|
||||
for (int64_t i = 0; i < l_out_2; ++i) {
|
||||
raw_out_2[i] = i < l_in_2 ? raw_in_2[i] : raw_in_3[i - l_in_2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
|
||||
static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
|
||||
|
|
@ -171,21 +170,30 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
|
|||
Ort::Global<void>::api_ = api->GetApi(ORT_API_VERSION);
|
||||
|
||||
static const CustomOpOne c_CustomOpOne;
|
||||
static const CustomOpTwo c_CustomOpTwo;
|
||||
static const std::unique_ptr<OrtCustomOp> c_CustomOpTwo{Ort::Custom::CreateCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)};
|
||||
|
||||
static const MulTopOpFloat c_MulTopOpFloat;
|
||||
static const MulTopOpInt32 c_MulTopOpInt32;
|
||||
static const std::unique_ptr<OrtCustomOp> c_MulTopOpFloat{Ort::Custom::CreateCustomOp("MulTop", "CPUExecutionProvider", MulTop<float>)};
|
||||
static const std::unique_ptr<OrtCustomOp> c_MulTopOpInt32{Ort::Custom::CreateCustomOp("MulTop", "CPUExecutionProvider", MulTop<int32_t>)};
|
||||
|
||||
static const std::unique_ptr<OrtCustomOp> fus_op_ptr{Ort::Custom::CreateCustomOp("Fuse", "CPUExecutionProvider", Fuse)};
|
||||
static const std::unique_ptr<OrtCustomOp> sel_op_ptr{Ort::Custom::CreateCustomOp("Select", "CPUExecutionProvider", Select)};
|
||||
static const std::unique_ptr<OrtCustomOp> fil_op_ptr{Ort::Custom::CreateCustomOp("Filter", "CPUExecutionProvider", Filter)};
|
||||
static const std::unique_ptr<OrtCustomOp> box_op_ptr{Ort::Custom::CreateCustomOp("Box", "CPUExecutionProvider", Box)};
|
||||
|
||||
OrtStatus* result = nullptr;
|
||||
|
||||
ORT_TRY {
|
||||
Ort::CustomOpDomain domain{c_OpDomain};
|
||||
domain.Add(&c_CustomOpOne);
|
||||
domain.Add(&c_CustomOpTwo);
|
||||
domain.Add(c_CustomOpTwo.get());
|
||||
|
||||
Ort::CustomOpDomain domain_v2{"v2"};
|
||||
domain_v2.Add(&c_MulTopOpFloat);
|
||||
domain_v2.Add(&c_MulTopOpInt32);
|
||||
domain_v2.Add(c_MulTopOpFloat.get());
|
||||
domain_v2.Add(c_MulTopOpInt32.get());
|
||||
domain_v2.Add(fus_op_ptr.get());
|
||||
domain_v2.Add(sel_op_ptr.get());
|
||||
domain_v2.Add(fil_op_ptr.get());
|
||||
domain_v2.Add(box_op_ptr.get());
|
||||
|
||||
Ort::UnownedSessionOptions session_options(options);
|
||||
session_options.Add(domain);
|
||||
|
|
|
|||
28
onnxruntime/test/testdata/fuse_select_filter.onnx
vendored
Normal file
28
onnxruntime/test/testdata/fuse_select_filter.onnx
vendored
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
:Ä
|
||||
P
|
||||
vector_1
|
||||
vector_2
|
||||
alphavector_fused fuse_node"Fuse*
|
||||
fuse_algo :v2
|
||||
4
|
||||
indicesindices_selectedselect_node"Select:v2
|
||||
N
|
||||
vector_fused
|
||||
indices_selectedvector_gatheredgather_node"GatherElements
|
||||
;
|
||||
vector_gatheredvector_filteredfilter_node"Filter:v2graphZ
|
||||
vector_1
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿZ
|
||||
vector_2
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿZ
|
||||
alpha
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿZ
|
||||
indices
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿb&
|
||||
vector_filtered
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿB
|
||||
15
onnxruntime/test/testdata/merge.onnx
vendored
Normal file
15
onnxruntime/test/testdata/merge.onnx
vendored
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
:¯
|
||||
D
|
||||
str_in_1
|
||||
str_in_2str_out
|
||||
merge_node"Merge*
|
||||
reverse :v2graphZ
|
||||
str_in_1
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿZ
|
||||
str_in_2
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿb
|
||||
str_out
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿB
|
||||
17
onnxruntime/test/testdata/optional_2.onnx
vendored
Normal file
17
onnxruntime/test/testdata/optional_2.onnx
vendored
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
:«
|
||||
8
|
||||
|
||||
float_in_1
|
||||
|
||||
float_in_2float_out_1box_node"Box:v2graphZ!
|
||||
|
||||
float_in_1
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿZ!
|
||||
|
||||
float_in_2
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿb"
|
||||
float_out_1
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿB
|
||||
26
onnxruntime/test/testdata/optional_3.onnx
vendored
Normal file
26
onnxruntime/test/testdata/optional_3.onnx
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
:‹
|
||||
Q
|
||||
|
||||
float_in_1
|
||||
|
||||
float_in_2
|
||||
|
||||
float_in_3float_out_1float_out_2box_node"Box:v2graphZ!
|
||||
|
||||
float_in_1
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿZ!
|
||||
|
||||
float_in_2
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿZ!
|
||||
|
||||
float_in_3
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿb"
|
||||
float_out_1
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿb"
|
||||
float_out_2
|
||||
|
||||
ÿÿÿÿÿÿÿÿÿB
|
||||
Loading…
Reference in a new issue