mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
This PR allows custom ops with different input types to share the same op name in a graph. The idea is to go over all ops with the same name and merge their input/output types into a single type-inference function. With this enhancement, custom op nodes inside a graph can have the same op type, provided that their input/output types differ. --------- Co-authored-by: Randy Shuai <rashuai@microsoft.com>
467 lines
No EOL
17 KiB
C++
467 lines
No EOL
17 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/session/onnxruntime_cxx_api.h"
|
|
#include <vector>
|
|
|
|
#ifdef USE_CUDA
|
|
#include <cuda_runtime.h>
|
|
#endif
|
|
|
|
// Plain aggregate describing one test input: its name, its dimension sizes,
// and its float values.
struct Input {
  const char* name{nullptr};  // not owned; null until assigned
  std::vector<int64_t> dims;  // dimension sizes
  std::vector<float> values;  // element values
};
|
|
|
|
struct MyCustomKernel {
|
|
MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* /*info*/)
|
|
: ort_(ort_api) {
|
|
}
|
|
|
|
void Compute(OrtKernelContext* context);
|
|
|
|
private:
|
|
const OrtApi& ort_;
|
|
};
|
|
|
|
struct MyCustomKernelSecondInputOnCpu {
|
|
MyCustomKernelSecondInputOnCpu(const OrtKernelInfo* /*info*/, void* compute_stream)
|
|
: compute_stream_(compute_stream) {
|
|
}
|
|
|
|
void Compute(OrtKernelContext* context);
|
|
|
|
private:
|
|
void* compute_stream_;
|
|
};
|
|
|
|
// "Foo" custom op: two float inputs and one float output. It runs on the
// execution provider named at construction, and its kernels are MyCustomKernel.
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
  // `provider` is stored, not copied, so the string must outlive this op.
  explicit MyCustomOp(const char* provider) : provider_(provider) {}

  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new MyCustomKernel(api, info);
  }

  const char* GetName() const {
    return "Foo";
  }

  const char* GetExecutionProviderType() const {
    return provider_;
  }

  // Inputs: exactly two, both float.
  size_t GetInputTypeCount() const {
    return 2;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Output: exactly one, float.
  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  // The default is always overwritten by the explicit constructor above.
  const char* provider_{"CPUExecutionProvider"};
};
|
|
|
|
// "Foo" custom op whose second input is requested in CPU memory while the first
// uses the default memory type. Its kernels receive the compute stream passed
// to the constructor.
struct MyCustomOpSecondInputOnCpu : Ort::CustomOpBase<MyCustomOpSecondInputOnCpu, MyCustomKernelSecondInputOnCpu> {
  // Both pointers are stored as-is (not owned).
  explicit MyCustomOpSecondInputOnCpu(const char* provider, void* compute_stream)
      : provider_(provider), compute_stream_(compute_stream) {}

  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
    return new MyCustomKernelSecondInputOnCpu(info, compute_stream_);
  }

  const char* GetName() const {
    return "Foo";
  }

  const char* GetExecutionProviderType() const {
    return provider_;
  }

  // Inputs: exactly two, both float.
  size_t GetInputTypeCount() const {
    return 2;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Input 1 is pinned to CPU memory; every other input uses the default.
  OrtMemType GetInputMemoryType(size_t i) const {
    return i == 1 ? OrtMemTypeCPUInput : OrtMemTypeDefault;
  }

  // Output: exactly one, float.
  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  // The default is always overwritten by the explicit constructor above.
  const char* provider_{"CUDAExecutionProvider"};
  void* compute_stream_;
};
|
|
|
|
struct MyCustomKernelMultipleDynamicInputs {
|
|
MyCustomKernelMultipleDynamicInputs(const OrtApi& ort_api, const OrtKernelInfo* /*info*/)
|
|
: ort_(ort_api) {
|
|
}
|
|
|
|
void Compute(OrtKernelContext* context);
|
|
|
|
private:
|
|
const OrtApi& ort_;
|
|
};
|
|
|
|
// "Foo" custom op with two dynamically typed inputs (each reported as
// UNDEFINED, so they may differ in type) and one float output.
struct MyCustomOpMultipleDynamicInputs
    : Ort::CustomOpBase<MyCustomOpMultipleDynamicInputs, MyCustomKernelMultipleDynamicInputs> {
  // `provider` is stored, not copied, so the string must outlive this op.
  explicit MyCustomOpMultipleDynamicInputs(const char* provider) : provider_(provider) {}

  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new MyCustomKernelMultipleDynamicInputs(api, info);
  }

  const char* GetName() const {
    return "Foo";
  }

  const char* GetExecutionProviderType() const {
    return provider_;
  }

  // Inputs: two, each of any (not necessarily matching) element type.
  size_t GetInputTypeCount() const {
    return 2;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  }

  // Output: exactly one, float.
  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_;
};
|
|
|
|
struct MyCustomKernelWithOptionalInput {
|
|
MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {
|
|
}
|
|
void Compute(OrtKernelContext* context);
|
|
};
|
|
|
|
// "FooBar" custom op: three float inputs, of which the second (index 1) is
// optional, and one required float output.
struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInput, MyCustomKernelWithOptionalInput> {
  // `provider` is stored, not copied, so the string must outlive this op.
  explicit MyCustomOpWithOptionalInput(const char* provider) : provider_(provider) {}

  void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
    return new MyCustomKernelWithOptionalInput(info);
  }

  const char* GetName() const {
    return "FooBar";
  }

  const char* GetExecutionProviderType() const {
    return provider_;
  }

  // Inputs: three floats; only index 1 may be omitted.
  size_t GetInputTypeCount() const {
    return 3;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
    return index == 1 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL
                      : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }

  // Output: one required float.
  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }

 private:
  const char* provider_;
};
|
|
|
|
// Custom kernel that outputs the lengths of all input strings.
|
|
struct MyCustomStringLengthsKernel {
|
|
explicit MyCustomStringLengthsKernel(const OrtKernelInfo* /* info */) {}
|
|
|
|
void Compute(OrtKernelContext* context);
|
|
};
|
|
|
|
// Utility function to be used when testing with MyCustomStringLengthsKernel.
|
|
// Creates an input tensor from the provided input string and adds it to `ort_inputs`.
|
|
// Also initializes the corresponding expected output and I/O names.
|
|
void AddInputForCustomStringLengthsKernel(std::string input_str, OrtAllocator* allocator,
|
|
std::vector<Ort::Value>& ort_inputs, std::vector<std::string>& input_names,
|
|
std::vector<std::string>& output_names,
|
|
std::vector<std::vector<int64_t>>& expected_dims,
|
|
std::vector<std::vector<int64_t>>& expected_outputs);
|
|
|
|
// Custom kernel that echos input arguments (shape [1]) in reversed order.
|
|
// Used to test variadic custom ops with heterogenous input types.
|
|
struct MyCustomEchoReversedArgsKernel {
|
|
explicit MyCustomEchoReversedArgsKernel(const OrtKernelInfo* /* info */) {}
|
|
void Compute(OrtKernelContext* context);
|
|
};
|
|
|
|
// Utility custom op class that can be configured with a kernel class (T) and input/output
|
|
// configurations.
|
|
template <typename T>
|
|
struct TemplatedCustomOp : Ort::CustomOpBase<TemplatedCustomOp<T>, T> {
|
|
TemplatedCustomOp(const char* op_name, std::vector<ONNXTensorElementDataType> input_types,
|
|
std::vector<OrtCustomOpInputOutputCharacteristic> input_characs, int input_min_arity,
|
|
bool input_homogeneity, std::vector<ONNXTensorElementDataType> output_types,
|
|
std::vector<OrtCustomOpInputOutputCharacteristic> output_characs, int output_min_arity,
|
|
bool output_homogeneity)
|
|
: op_name_(op_name),
|
|
input_types_(std::move(input_types)),
|
|
input_characs_(std::move(input_characs)),
|
|
input_min_arity_(input_min_arity),
|
|
input_homogeneity_(input_homogeneity),
|
|
output_types_(std::move(output_types)),
|
|
output_characs_(std::move(output_characs)),
|
|
output_min_arity_(output_min_arity),
|
|
output_homogeneity_(output_homogeneity) {}
|
|
|
|
void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const {
|
|
return new T(info);
|
|
}
|
|
|
|
const char* GetName() const noexcept { return op_name_; }
|
|
|
|
size_t GetInputTypeCount() const noexcept { return input_types_.size(); }
|
|
|
|
ONNXTensorElementDataType GetInputType(size_t index) const noexcept {
|
|
return input_types_[index];
|
|
}
|
|
|
|
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const noexcept {
|
|
return input_characs_[index];
|
|
}
|
|
|
|
int GetVariadicInputMinArity() const noexcept {
|
|
return input_min_arity_;
|
|
}
|
|
|
|
bool GetVariadicInputHomogeneity() const noexcept {
|
|
return input_homogeneity_;
|
|
}
|
|
|
|
size_t GetOutputTypeCount() const noexcept { return output_types_.size(); }
|
|
|
|
ONNXTensorElementDataType GetOutputType(size_t index) const noexcept {
|
|
return output_types_[index];
|
|
}
|
|
|
|
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const noexcept {
|
|
return output_characs_[index];
|
|
}
|
|
|
|
int GetVariadicOutputMinArity() const noexcept {
|
|
return output_min_arity_;
|
|
}
|
|
|
|
bool GetVariadicOutputHomogeneity() const noexcept {
|
|
return output_homogeneity_;
|
|
}
|
|
|
|
private:
|
|
const char* op_name_;
|
|
|
|
std::vector<ONNXTensorElementDataType> input_types_;
|
|
std::vector<OrtCustomOpInputOutputCharacteristic> input_characs_;
|
|
int input_min_arity_;
|
|
bool input_homogeneity_;
|
|
|
|
std::vector<ONNXTensorElementDataType> output_types_;
|
|
std::vector<OrtCustomOpInputOutputCharacteristic> output_characs_;
|
|
int output_min_arity_;
|
|
bool output_homogeneity_;
|
|
};
|
|
|
|
struct MyCustomKernelWithAttributes {
|
|
MyCustomKernelWithAttributes(const OrtKernelInfo* kernel_info) {
|
|
Ort::ConstKernelInfo info{kernel_info};
|
|
int_attr_ = info.GetAttribute<int64_t>("int_attr");
|
|
float_attr_ = info.GetAttribute<float>("float_attr");
|
|
|
|
ints_attr_ = info.GetAttributes<int64_t>("ints_attr");
|
|
floats_attr_ = info.GetAttributes<float>("floats_attr");
|
|
|
|
string_arr_ = info.GetAttribute<std::string>("string_attr");
|
|
}
|
|
|
|
void Compute(OrtKernelContext* context);
|
|
|
|
private:
|
|
int64_t int_attr_;
|
|
float float_attr_;
|
|
|
|
std::vector<int64_t> ints_attr_;
|
|
std::vector<float> floats_attr_;
|
|
|
|
std::string string_arr_;
|
|
};
|
|
|
|
// "FooBar_Attr" custom op: one float input and one float output. Its kernels
// (MyCustomKernelWithAttributes) read node attributes when they are created.
struct MyCustomOpWithAttributes : Ort::CustomOpBase<MyCustomOpWithAttributes, MyCustomKernelWithAttributes> {
  // `provider` is stored, not copied, so the string must outlive this op.
  explicit MyCustomOpWithAttributes(const char* provider) : provider_(provider) {}

  void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const {
    return new MyCustomKernelWithAttributes(info);
  }

  const char* GetName() const {
    return "FooBar_Attr";
  }

  const char* GetExecutionProviderType() const {
    return provider_;
  }

  // Input: exactly one, float.
  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Output: exactly one, float.
  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_;
};
|
|
|
|
// Slice array of floats or doubles between [from, to) and save to output
|
|
struct SliceCustomOpKernel {
|
|
SliceCustomOpKernel(const OrtKernelInfo* /*info*/) {
|
|
}
|
|
|
|
void Compute(OrtKernelContext* context);
|
|
};
|
|
|
|
// "Slice" custom op with three inputs: a data array whose element type is left
// open (float or double), followed by the int64 bounds "from" and "to".
// Its single output is also left untyped.
struct SliceCustomOp : Ort::CustomOpBase<SliceCustomOp, SliceCustomOpKernel> {
  // `provider` is stored, not copied, so the string must outlive this op.
  explicit SliceCustomOp(const char* provider) : provider_(provider) {}

  void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const {
    return new SliceCustomOpKernel(info);
  }

  const char* GetName() const {
    return "Slice";
  }

  const char* GetExecutionProviderType() const {
    return provider_;
  }

  size_t GetInputTypeCount() const {
    return 3;
  }

  // Index 0 is the data array (float or double, so reported as undefined).
  // Every other index (1 = "from", 2 = "to") is int64.
  ONNXTensorElementDataType GetInputType(size_t index) const {
    return index == 0 ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
                      : ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  }

 private:
  const char* provider_;
};
|
|
|
|
struct StandaloneCustomKernel {
|
|
StandaloneCustomKernel(const OrtKernelInfo* info);
|
|
~StandaloneCustomKernel();
|
|
void Compute(OrtKernelContext* context);
|
|
|
|
private:
|
|
#if !defined(REDUCED_OPS_BUILD)
|
|
void InitTopK();
|
|
void InvokeTopK(OrtKernelContext* context);
|
|
|
|
void InitGru();
|
|
void InvokeGru(OrtKernelContext* context);
|
|
|
|
void InitInvokeConv(OrtKernelContext* context); // create Conv and invoke in Compute(...)
|
|
|
|
Ort::Op op_topk_{nullptr};
|
|
Ort::Op op_gru_{nullptr};
|
|
#endif
|
|
|
|
Ort::KernelInfo info_copy_{nullptr};
|
|
Ort::Op op_add_{nullptr};
|
|
};
|
|
|
|
// "Foo" custom op backed by StandaloneCustomKernel: two float inputs and one
// float output, on the execution provider named at construction.
struct StandaloneCustomOp : Ort::CustomOpBase<StandaloneCustomOp, StandaloneCustomKernel> {
  // `provider` is stored, not copied, so the string must outlive this op.
  explicit StandaloneCustomOp(const char* provider) : provider_(provider) {}

  void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const {
    return new StandaloneCustomKernel(info);
  }

  const char* GetName() const {
    return "Foo";
  }

  const char* GetExecutionProviderType() const {
    return provider_;
  }

  // Inputs: exactly two, both float.
  size_t GetInputTypeCount() const {
    return 2;
  }

  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Output: exactly one, float.
  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_;
};
|
|
|
|
/////////////// structures to test multi-kernels-single-schema ///////////////
|
|
|
|
struct MulTopKernelFloat {
|
|
MulTopKernelFloat(const OrtKernelInfo*){};
|
|
~MulTopKernelFloat() = default;
|
|
void Compute(OrtKernelContext*){};
|
|
};
|
|
|
|
// "MulTop" on CPU with one float input and one float output. This is the
// reference variant the other MulTop ops are compared against.
struct MulTopOpFloat : Ort::CustomOpBase<MulTopOpFloat, MulTopKernelFloat> {
  void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const {
    return new MulTopKernelFloat(info);
  }

  const char* GetName() const {
    return "MulTop";
  }

  const char* GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};
|
|
|
|
struct MulTopKernelInt32 {
|
|
MulTopKernelInt32(const OrtKernelInfo*){};
|
|
~MulTopKernelInt32() = default;
|
|
void Compute(OrtKernelContext*){};
|
|
};
|
|
|
|
// "MulTop" on CPU with one int32 input and one int32 output. It matches
// MulTopOpFloat's input/output counts and differs only in element type.
struct MulTopOpInt32 : Ort::CustomOpBase<MulTopOpInt32, MulTopKernelInt32> {
  void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const {
    return new MulTopKernelInt32(info);
  }

  const char* GetName() const {
    return "MulTop";
  }

  const char* GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  }
};
|
|
|
|
struct MulTopKernelDouble {
|
|
MulTopKernelDouble(const OrtKernelInfo*){};
|
|
~MulTopKernelDouble() = default;
|
|
void Compute(OrtKernelContext*){};
|
|
};
|
|
|
|
// MulTopOpDouble and MulTopOpFloat has input count mismatch
|
|
struct MulTopOpDouble : Ort::CustomOpBase<MulTopOpDouble, MulTopKernelDouble> {
|
|
void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelDouble(info); }
|
|
const char* GetName() const { return "MulTop"; }
|
|
const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
|
|
size_t GetInputTypeCount() const { return 2; }
|
|
ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; }
|
|
size_t GetOutputTypeCount() const { return 1; }
|
|
ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; }
|
|
};
|
|
|
|
struct MulTopKernelInt16 {
|
|
MulTopKernelInt16(const OrtKernelInfo*){};
|
|
~MulTopKernelInt16() = default;
|
|
void Compute(OrtKernelContext*){};
|
|
};
|
|
|
|
// MulTopOpInt16 and MulTopOpFloat has output count mismatch
|
|
struct MulTopOpInt16 : Ort::CustomOpBase<MulTopOpInt16, MulTopKernelInt16> {
|
|
void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const { return new MulTopKernelInt16(info); }
|
|
const char* GetName() const { return "MulTop"; }
|
|
const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
|
|
size_t GetInputTypeCount() const { return 1; }
|
|
ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; }
|
|
size_t GetOutputTypeCount() const { return 2; }
|
|
ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; }
|
|
};
|
|
|
|
// MulTopKernelFloat16 and MulTopOpFloat has input characteristic mismatch
|
|
struct MulTopKernelFloat16 {
|
|
MulTopKernelFloat16(const OrtKernelInfo*){};
|
|
~MulTopKernelFloat16() = default;
|
|
void Compute(OrtKernelContext*){};
|
|
};
|
|
|
|
// "MulTop" on CPU with one float16 input and one float16 output. It declares
// its input OPTIONAL, whereas MulTopOpFloat keeps the base-class default
// characteristic: an input characteristic mismatch under the same op name.
struct MulTopOpFloat16 : Ort::CustomOpBase<MulTopOpFloat16, MulTopKernelFloat16> {
  void* CreateKernel(const OrtApi&, const OrtKernelInfo* info) const {
    return new MulTopKernelFloat16(info);
  }

  const char* GetName() const {
    return "MulTop";
  }

  const char* GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  }

  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
  }
};