mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Support parsing an array of values stored as an attribute in a custom op (#6878)
This commit is contained in:
parent
e64eff1f13
commit
c8e2e3191b
10 changed files with 272 additions and 8 deletions
|
|
@ -728,10 +728,40 @@ struct OrtApi {
|
|||
ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in,
|
||||
_Out_ void* data_container, size_t data_container_size);
|
||||
|
||||
/**
|
||||
* Fetch a float stored as an attribute in the graph node
|
||||
* \info - OrtKernelInfo instance
|
||||
* \name - name of the attribute to be parsed
|
||||
* \out - pointer to memory where the attribute is to be stored
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
|
||||
_Out_ float* out);
|
||||
|
||||
/**
|
||||
* Fetch a 64-bit int stored as an attribute in the graph node
|
||||
* \info - OrtKernelInfo instance
|
||||
* \name - name of the attribute to be parsed
|
||||
* \out - pointer to memory where the attribute is to be stored
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
|
||||
_Out_ int64_t* out);
|
||||
/**
|
||||
* Fetch a string stored as an attribute in the graph node
|
||||
* \info - OrtKernelInfo instance
|
||||
* \name - name of the attribute to be parsed
|
||||
* \out - pointer to memory where the attribute's contents are to be stored
|
||||
* \size - actual size of string attribute
|
||||
* (If `out` is nullptr, the value of `size` is set to the true size of the string
|
||||
attribute, and a success status is returned.
|
||||
|
||||
If the `size` parameter is greater than or equal to the actual string attribute's size,
|
||||
the value of `size` is set to the true size of the string attribute, the provided memory
|
||||
is filled with the attribute's contents, and a success status is returned.
|
||||
|
||||
If the `size` parameter is lesser than the actual string attribute's size and `out`
|
||||
is not nullptr, the value of `size` is set to the true size of the string attribute
|
||||
and a failure status is returned.)
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out,
|
||||
_Inout_ size_t* size);
|
||||
|
||||
|
|
@ -1179,6 +1209,48 @@ struct OrtApi {
|
|||
* Get the current device id of the GPU execution provider (cuda/tensorrt/rocm).
|
||||
*/
|
||||
ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id);
|
||||
|
||||
/**
|
||||
* Fetch an array of float values stored as an attribute in the graph node
|
||||
* \info - OrtKernelInfo instance
|
||||
* \name - name of the attribute to be parsed
|
||||
* \out - pointer to memory where the attribute's contents are to be stored
|
||||
* \size - actual size of attribute array
|
||||
* (If `out` is nullptr, the value of `size` is set to the true size of the attribute
|
||||
array's size, and a success status is returned.
|
||||
|
||||
If the `size` parameter is greater than or equal to the actual attribute array's size,
|
||||
the value of `size` is set to the true size of the attribute array's size,
|
||||
the provided memory is filled with the attribute's contents,
|
||||
and a success status is returned.
|
||||
|
||||
If the `size` parameter is lesser than the actual attribute array's size and `out`
|
||||
is not nullptr, the value of `size` is set to the true size of the attribute array's size
|
||||
and a failure status is returned.)
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
|
||||
_Out_ float* out, _Inout_ size_t* size);
|
||||
|
||||
/**
|
||||
* Fetch an array of int64_t values stored as an attribute in the graph node
|
||||
* \info - OrtKernelInfo instance
|
||||
* \name - name of the attribute to be parsed
|
||||
* \out - pointer to memory where the attribute's contents are to be stored
|
||||
* \size - actual size of attribute array
|
||||
* (If `out` is nullptr, the value of `size` is set to the true size of the attribute
|
||||
array's size, and a success status is returned.
|
||||
|
||||
If the `size` parameter is greater than or equal to the actual attribute array's size,
|
||||
the value of `size` is set to the true size of the attribute array's size,
|
||||
the provided memory is filled with the attribute's contents,
|
||||
and a success status is returned.
|
||||
|
||||
If the `size` parameter is lesser than the actual attribute array's size and `out`
|
||||
is not nullptr, the value of `size` is set to the true size of the attribute array's size
|
||||
and a failure status is returned.)
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
|
||||
_Out_ int64_t* out, _Inout_ size_t* size);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -578,7 +578,7 @@ struct ArenaCfg : Base<OrtArenaCfg> {
|
|||
struct CustomOpApi {
|
||||
CustomOpApi(const OrtApi& api) : api_(api) {}
|
||||
|
||||
template <typename T> // T is only implemented for float, int64_t, and string
|
||||
template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
|
||||
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
|
||||
|
|
|
|||
|
|
@ -877,11 +877,11 @@ template <>
|
|||
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
||||
size_t size = 0;
|
||||
std::string out;
|
||||
|
||||
// Feed nullptr for the data buffer to query the true size of the string attribute
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
|
||||
|
||||
// The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
|
||||
if (api_.GetErrorCode(status) == ORT_INVALID_ARGUMENT) {
|
||||
api_.ReleaseStatus(status);
|
||||
if (status == nullptr) {
|
||||
out.resize(size);
|
||||
ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
|
||||
out.resize(size - 1); // remove the terminating character '\0'
|
||||
|
|
@ -891,6 +891,39 @@ inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const O
|
|||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
||||
size_t size = 0;
|
||||
std::vector<float> out;
|
||||
|
||||
// Feed nullptr for the data buffer to query the true size of the attribute
|
||||
OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
|
||||
|
||||
if (status == nullptr) {
|
||||
out.resize(size);
|
||||
ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
|
||||
} else {
|
||||
ThrowOnError(status);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
||||
size_t size = 0;
|
||||
std::vector<int64_t> out;
|
||||
|
||||
// Feed nullptr for the data buffer to query the true size of the attribute
|
||||
OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
|
||||
|
||||
if (status == nullptr) {
|
||||
out.resize(size);
|
||||
ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
|
||||
} else {
|
||||
ThrowOnError(status);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* out;
|
||||
ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
#include <type_traits>
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
|
||||
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
|
||||
|
|
@ -53,12 +54,15 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel
|
|||
std::string value;
|
||||
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<std::string>(name, &value);
|
||||
if (status.IsOK()) {
|
||||
if (*size >= value.size() + 1) {
|
||||
if (out == nullptr) { // User is querying the true size of the attribute
|
||||
*size = value.size() + 1;
|
||||
return nullptr;
|
||||
} else if (*size >= value.size() + 1) {
|
||||
std::memcpy(out, value.data(), value.size());
|
||||
out[value.size()] = '\0';
|
||||
*size = value.size() + 1;
|
||||
return nullptr;
|
||||
} else {
|
||||
} else { // User has provided a buffer that is not large enough
|
||||
*size = value.size() + 1;
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough");
|
||||
}
|
||||
|
|
@ -66,6 +70,42 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel
|
|||
return onnxruntime::ToOrtStatus(status);
|
||||
}
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_fundamental<T>::value, int>::type = 0>
|
||||
static Status CopyDataFromVectorToMemory(const std::vector<T>& values, T* out, size_t* size) {
|
||||
if (out == nullptr) { // User is querying the true size of the attribute
|
||||
*size = values.size();
|
||||
return Status::OK();
|
||||
} else if (*size >= values.size()) {
|
||||
std::memcpy(out, values.data(), values.size() * sizeof(T));
|
||||
*size = values.size();
|
||||
} else { // User has provided a buffer that is not large enough
|
||||
*size = values.size();
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Result buffer is not large enough");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
|
||||
_Out_ float* out, _Inout_ size_t* size) {
|
||||
std::vector<float> values;
|
||||
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttrs<float>(name, values);
|
||||
if (status.IsOK()) {
|
||||
status = CopyDataFromVectorToMemory<float>(values, out, size);
|
||||
}
|
||||
return onnxruntime::ToOrtStatus(status);
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
|
||||
_Out_ int64_t* out, _Inout_ size_t* size) {
|
||||
std::vector<int64_t> values;
|
||||
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttrs<int64_t>(name, values);
|
||||
if (status.IsOK()) {
|
||||
status = CopyDataFromVectorToMemory<int64_t>(values, out, size);
|
||||
}
|
||||
return onnxruntime::ToOrtStatus(status);
|
||||
}
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
||||
#include "core/framework/customregistry.h"
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -2111,6 +2111,8 @@ static constexpr OrtApi ort_api_1_to_8 = {
|
|||
// End of Version 7 - DO NOT MODIFY ABOVE (see above text for more information)
|
||||
|
||||
// Version 8 - In development, feel free to add/remove/rearrange here
|
||||
&OrtApis::KernelInfoGetAttributeArray_float,
|
||||
&OrtApis::KernelInfoGetAttributeArray_int64,
|
||||
};
|
||||
|
||||
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
|
||||
|
|
|
|||
|
|
@ -259,4 +259,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_TensorRT,
|
|||
_In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options);
|
||||
ORT_API_STATUS_IMPL(SetCurrentGpuDeviceId, _In_ int device_id);
|
||||
ORT_API_STATUS_IMPL(GetCurrentGpuDeviceId, _In_ int* device_id);
|
||||
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size);
|
||||
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size);
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -167,6 +167,7 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/Var
|
|||
static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx");
|
||||
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
|
||||
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
|
||||
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
|
||||
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
|
||||
|
|
@ -529,6 +530,49 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
|
|||
}
|
||||
}
|
||||
}
|
||||
// Exercises the attribute-parsing paths of a custom op: the kernel reads int,
// float, int-array, float-array, and string attributes and folds them into the
// output, so a single value check validates all five parses at once.
TEST(CApiTest, custom_op_with_attributes_handler) {
  MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider};

  Ort::CustomOpDomain custom_op_domain("");
  custom_op_domain.Add(&custom_op);

  Ort::SessionOptions session_options;
  session_options.Add(custom_op_domain);

  Ort::Session session(*ort_env, CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI, session_options);

  Ort::MemoryInfo mem_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);

  // Single float input "X" with one element.
  std::vector<const char*> input_names{"X"};
  std::vector<float> x_values = {1.f};
  std::vector<int64_t> x_shape = {1};
  std::vector<Ort::Value> inputs;
  inputs.emplace_back(Ort::Value::CreateTensor<float>(mem_info, x_values.data(), x_values.size(),
                                                      x_shape.data(), x_shape.size()));

  // Run the session and fetch output "Y".
  const char* output_name = "Y";
  auto outputs = session.Run(Ort::RunOptions{}, input_names.data(), inputs.data(), inputs.size(),
                             &output_name, 1);
  ASSERT_EQ(outputs.size(), 1u);

  // Expected: 1 + float_attr + int_attr + floats_attr[0..1] + ints_attr[0..1] = 15.
  const std::vector<int64_t> expected_shape = {1};
  const std::vector<float> expected_values = {15.f};

  auto shape_info = outputs[0].GetTensorTypeAndShapeInfo();
  ASSERT_EQ(shape_info.GetShape(), expected_shape);
  const size_t element_count = shape_info.GetElementCount();
  ASSERT_EQ(expected_values.size(), element_count);

  const float* actual = outputs[0].GetTensorMutableData<float>();
  for (size_t i = 0; i != element_count; ++i) {
    ASSERT_EQ(expected_values[i], actual[i]);
  }
}
|
||||
|
||||
// Tests registration of a custom op of the same name for both CPU and CUDA EPs
|
||||
#ifdef USE_CUDA
|
||||
|
|
|
|||
|
|
@ -87,3 +87,32 @@ void MyCustomKernelWithOptionalInput::Compute(OrtKernelContext* context) {
|
|||
out[i] = X1[i] + (X2 != nullptr ? X2[i] : 0) + X3[i];
|
||||
}
|
||||
}
|
||||
|
||||
void MyCustomKernelWithAttributes::Compute(OrtKernelContext* context) {
  // Input tensor 0 (float).
  const OrtValue* input_value = ort_.KernelContext_GetInput(context, 0);
  const float* input_data = ort_.GetTensorData<float>(input_value);

  // Output tensor with the same shape as the input.
  OrtTensorDimensions dims(ort_, input_value);
  OrtValue* output_value = ort_.KernelContext_GetOutput(context, 0, dims.data(), dims.size());
  float* output_data = ort_.GetTensorMutableData<float>(output_value);

  OrtTensorTypeAndShapeInfo* shape_info = ort_.GetTensorTypeAndShape(output_value);
  const int64_t element_count = ort_.GetTensorShapeElementCount(shape_info);
  ort_.ReleaseTensorTypeAndShapeInfo(shape_info);

  // This kernel only supports the CPU EP.
  if (string_arr_ != "add") {
    // String attribute did not parse as expected: produce zeros so the test's
    // result comparison fails and flags the bad parse.
    for (int64_t i = 0; i < element_count; i++) {
      output_data[i] = 0.f;
    }
    return;
  }

  // Fold every parsed attribute into each output element; a correct result
  // proves all attribute parses (scalar and array) succeeded.
  for (int64_t i = 0; i < element_count; i++) {
    output_data[i] = input_data[i] +
                     float_attr_ + static_cast<float>(int_attr_) +
                     floats_attr_[0] + floats_attr_[1] +
                     static_cast<float>(ints_attr_[0]) + static_cast<float>(ints_attr_[1]);
  }
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include <vector>
|
||||
|
||||
struct Input {
|
||||
const char* name = nullptr;
|
||||
|
|
@ -81,7 +82,6 @@ struct MyCustomOpMultipleDynamicInputs : Ort::CustomOpBase<MyCustomOpMultipleDyn
|
|||
struct MyCustomKernelWithOptionalInput {
|
||||
MyCustomKernelWithOptionalInput(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
|
@ -113,4 +113,46 @@ struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInp
|
|||
|
||||
private:
|
||||
const char* provider_;
|
||||
};
|
||||
};
|
||||
|
||||
// Test kernel that parses one attribute of each supported kind (int64, float,
// int64 array, float array, string) at construction time and combines them in
// Compute (defined in the .cc file).
struct MyCustomKernelWithAttributes {
  // Reads all five attributes from the node via CustomOpApi; each
  // KernelInfoGetAttribute call throws on failure.
  MyCustomKernelWithAttributes(Ort::CustomOpApi ort, const OrtKernelInfo* info) : ort_(ort) {
    int_attr_ = ort_.KernelInfoGetAttribute<int64_t>(info, "int_attr");
    float_attr_ = ort_.KernelInfoGetAttribute<float>(info, "float_attr");

    // Array-valued attributes exercise the KernelInfoGetAttributeArray_* APIs.
    ints_attr_ = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "ints_attr");
    floats_attr_ = ort_.KernelInfoGetAttribute<std::vector<float>>(info, "floats_attr");

    string_arr_ = ort_.KernelInfoGetAttribute<std::string>(info, "string_attr");
  }

  void Compute(OrtKernelContext* context);

 private:
  Ort::CustomOpApi ort_;

  // Scalar attributes.
  int64_t int_attr_;
  float float_attr_;

  // Array attributes (Compute reads elements [0] and [1]).
  std::vector<int64_t> ints_attr_;
  std::vector<float> floats_attr_;

  // String attribute; Compute uses it as an opcode selector ("add").
  std::string string_arr_;
};
|
||||
|
||||
// Custom-op schema for the attribute-parsing test kernel: one float input,
// one float output, named "FooBar_Attr".
struct MyCustomOpWithAttributes : Ort::CustomOpBase<MyCustomOpWithAttributes, MyCustomKernelWithAttributes> {
  explicit MyCustomOpWithAttributes(const char* provider) : provider_(provider) {}

  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
    return new MyCustomKernelWithAttributes(api, info);
  }

  const char* GetName() const { return "FooBar_Attr"; }

  // EP is injected so the same op definition can be registered per provider.
  const char* GetExecutionProviderType() const { return provider_; }

  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

 private:
  const char* provider_;
};
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/foo_bar_3.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/foo_bar_3.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue