diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5ca4b9b1e0..719a881a0e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -728,10 +728,40 @@ struct OrtApi { ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size); + /** + * Fetch a float stored as an attribute in the graph node + * \info - OrtKernelInfo instance + * \name - name of the attribute to be parsed + * \out - pointer to memory where the attribute is to be stored + */ ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); + + /** + * Fetch a 64-bit int stored as an attribute in the graph node + * \info - OrtKernelInfo instance + * \name - name of the attribute to be parsed + * \out - pointer to memory where the attribute is to be stored + */ ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); + /** + * Fetch a string stored as an attribute in the graph node + * \info - OrtKernelInfo instance + * \name - name of the attribute to be parsed + * \out - pointer to memory where the attribute's contents are to be stored + * \size - actual size of string attribute + * (If `out` is nullptr, the value of `size` is set to the true size of the string + attribute, and a success status is returned. + + If the `size` parameter is greater than or equal to the actual string attribute's size, + the value of `size` is set to the true size of the string attribute, the provided memory + is filled with the attribute's contents, and a success status is returned. 
+ + If the `size` parameter is lesser than the actual string attribute's size and `out` + is not nullptr, the value of `size` is set to the true size of the string attribute + and a failure status is returned.) + */ ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size); @@ -1179,6 +1209,48 @@ struct OrtApi { * Get the current device id of the GPU execution provider (cuda/tensorrt/rocm). */ ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id); + + /** + * Fetch an array of float values stored as an attribute in the graph node + * \info - OrtKernelInfo instance + * \name - name of the attribute to be parsed + * \out - pointer to memory where the attribute's contents are to be stored + * \size - actual size of attribute array + * (If `out` is nullptr, the value of `size` is set to the true size of the attribute + array's size, and a success status is returned. + + If the `size` parameter is greater than or equal to the actual attribute array's size, + the value of `size` is set to the true size of the attribute array's size, + the provided memory is filled with the attribute's contents, + and a success status is returned. + + If the `size` parameter is lesser than the actual attribute array's size and `out` + is not nullptr, the value of `size` is set to the true size of the attribute array's size + and a failure status is returned.) 
+ */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out, _Inout_ size_t* size); + + /** + * Fetch an array of int64_t values stored as an attribute in the graph node + * \info - OrtKernelInfo instance + * \name - name of the attribute to be parsed + * \out - pointer to memory where the attribute's contents are to be stored + * \size - actual size of attribute array + * (If `out` is nullptr, the value of `size` is set to the true size of the attribute + array's size, and a success status is returned. + + If the `size` parameter is greater than or equal to the actual attribute array's size, + the value of `size` is set to the true size of the attribute array's size, + the provided memory is filled with the attribute's contents, + and a success status is returned. + + If the `size` parameter is lesser than the actual attribute array's size and `out` + is not nullptr, the value of `size` is set to the true size of the attribute array's size + and a failure status is returned.) 
+ */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out, _Inout_ size_t* size); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index f9c87c0a0c..55ae18e270 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -578,7 +578,7 @@ struct ArenaCfg : Base { struct CustomOpApi { CustomOpApi(const OrtApi& api) : api_(api) {} - template // T is only implemented for float, int64_t, and string + template // T is only implemented for std::vector, std::vector, float, int64_t, and string T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index a818c3c691..d27d9055f6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -877,11 +877,11 @@ template <> inline std::string CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { size_t size = 0; std::string out; + + // Feed nullptr for the data buffer to query the true size of the string attribute OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size); - // The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string - if (api_.GetErrorCode(status) == ORT_INVALID_ARGUMENT) { - api_.ReleaseStatus(status); + if (status == nullptr) { out.resize(size); ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size)); out.resize(size - 1); // remove the terminating character '\0' @@ -891,6 +891,39 @@ inline std::string CustomOpApi::KernelInfoGetAttribute(_In_ const O return out; 
} +template <> +inline std::vector CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + size_t size = 0; + std::vector out; + + // Feed nullptr for the data buffer to query the true size of the attribute + OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size); + + if (status == nullptr) { + out.resize(size); + ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size)); + } else { + ThrowOnError(status); + } + return out; +} + +template <> +inline std::vector CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) { + size_t size = 0; + std::vector out; + + // Feed nullptr for the data buffer to query the true size of the attribute + OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size); + + if (status == nullptr) { + out.resize(size); + ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size)); + } else { + ThrowOnError(status); + } + return out; +} inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) { OrtTensorTypeAndShapeInfo* out; ThrowOnError(api_.GetTensorTypeAndShape(value, &out)); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 25d97a6b00..3b60ab4ddc 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -13,6 +13,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "core/session/ort_apis.h" +#include ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { auto status = reinterpret_cast(info)->GetAttr(name, out); @@ -53,12 +54,15 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel std::string value; auto status = reinterpret_cast(info)->GetAttr(name, &value); if (status.IsOK()) { 
- if (*size >= value.size() + 1) { + if (out == nullptr) { // User is querying the true size of the attribute + *size = value.size() + 1; + return nullptr; + } else if (*size >= value.size() + 1) { std::memcpy(out, value.data(), value.size()); out[value.size()] = '\0'; *size = value.size() + 1; return nullptr; - } else { + } else { // User has provided a buffer that is not large enough *size = value.size() + 1; return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough"); } @@ -66,6 +70,42 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel return onnxruntime::ToOrtStatus(status); } +template ::value, int>::type = 0> +static Status CopyDataFromVectorToMemory(const std::vector& values, T* out, size_t* size) { + if (out == nullptr) { // User is querying the true size of the attribute + *size = values.size(); + return Status::OK(); + } else if (*size >= values.size()) { + std::memcpy(out, values.data(), values.size() * sizeof(T)); + *size = values.size(); + } else { // User has provided a buffer that is not large enough + *size = values.size(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Result buffer is not large enough"); + } + + return Status::OK(); +} + +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out, _Inout_ size_t* size) { + std::vector values; + auto status = reinterpret_cast(info)->GetAttrs(name, values); + if (status.IsOK()) { + status = CopyDataFromVectorToMemory(values, out, size); + } + return onnxruntime::ToOrtStatus(status); +} + +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out, _Inout_ size_t* size) { + std::vector values; + auto status = reinterpret_cast(info)->GetAttrs(name, values); + if (status.IsOK()) { + status = CopyDataFromVectorToMemory(values, out, size); + } + return 
onnxruntime::ToOrtStatus(status); +} + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) #include "core/framework/customregistry.h" namespace onnxruntime { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 24d94bc625..a98af18f18 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2111,6 +2111,8 @@ static constexpr OrtApi ort_api_1_to_8 = { // End of Version 7 - DO NOT MODIFY ABOVE (see above text for more information) // Version 8 - In development, feel free to add/remove/rearrange here + &OrtApis::KernelInfoGetAttributeArray_float, + &OrtApis::KernelInfoGetAttributeArray_int64, }; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 10ab7328f3..fe6e358771 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -259,4 +259,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options); ORT_API_STATUS_IMPL(SetCurrentGpuDeviceId, _In_ int device_id); ORT_API_STATUS_IMPL(GetCurrentGpuDeviceId, _In_ int* device_id); +ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size); +ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size); } // namespace OrtApis diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 4f4a03c818..cd70eaf64c 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -167,6 +167,7 
@@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/Var static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx"); static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx"); static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx"); +static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx"); #ifdef ENABLE_LANGUAGE_INTEROP_OPS static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx"); @@ -529,6 +530,49 @@ TEST(CApiTest, optional_input_output_custom_op_handler) { } } } +TEST(CApiTest, custom_op_with_attributes_handler) { + MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider}; + + Ort::CustomOpDomain custom_op_domain(""); + custom_op_domain.Add(&custom_op); + + Ort::SessionOptions session_options; + session_options.Add(custom_op_domain); + + Ort::Session session(*ort_env, CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI, session_options); + + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + + std::vector ort_inputs; + std::vector input_names; + + // input 0 (float type) + input_names.emplace_back("X"); + std::vector input_0_data = {1.f}; + std::vector input_0_dims = {1}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(info, const_cast(input_0_data.data()), + input_0_data.size(), input_0_dims.data(), input_0_dims.size())); + + // Run + const char* output_name = "Y"; + auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, 1); + ASSERT_EQ(ort_outputs.size(), 1u); + + // Validate results + std::vector y_dims = {1}; + std::vector values_y = {15.f}; + auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), y_dims); + size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(values_y.size(), total_len); + + float* f = 
ort_outputs[0].GetTensorMutableData(); + for (size_t i = 0; i != total_len; ++i) { + ASSERT_EQ(values_y[i], f[i]); + } +} // Tests registration of a custom op of the same name for both CPU and CUDA EPs #ifdef USE_CUDA diff --git a/onnxruntime/test/shared_lib/utils.cc b/onnxruntime/test/shared_lib/utils.cc index 6267cc8d0d..046fbf6e8d 100644 --- a/onnxruntime/test/shared_lib/utils.cc +++ b/onnxruntime/test/shared_lib/utils.cc @@ -87,3 +87,32 @@ void MyCustomKernelWithOptionalInput::Compute(OrtKernelContext* context) { out[i] = X1[i] + (X2 != nullptr ? X2[i] : 0) + X3[i]; } } + +void MyCustomKernelWithAttributes::Compute(OrtKernelContext* context) { + // Setup inputs + const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); + const float* X = ort_.GetTensorData(input_X); + + // Setup output + OrtTensorDimensions dimensions(ort_, input_X); + OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size()); + float* out = ort_.GetTensorMutableData(output); + + OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output); + int64_t size = ort_.GetTensorShapeElementCount(output_info); + ort_.ReleaseTensorTypeAndShapeInfo(output_info); + + // This kernel only supports CPU EP + if (string_arr_ == "add") { // Test that the string attribute parsing went correctly + for (int64_t i = 0; i < size; i++) { + out[i] = X[i] + + float_attr_ + static_cast(int_attr_) + + floats_attr_[0] + floats_attr_[1] + + static_cast(ints_attr_[0]) + static_cast(ints_attr_[1]); + } + } else { // if the string attribute parsing had not gone correctly - it will trigger this path and fail the test due to result mis-match + for (int64_t i = 0; i < size; i++) { + out[i] = 0.f; + } + } +} diff --git a/onnxruntime/test/shared_lib/utils.h b/onnxruntime/test/shared_lib/utils.h index 19b8467f3f..b719475c27 100644 --- a/onnxruntime/test/shared_lib/utils.h +++ b/onnxruntime/test/shared_lib/utils.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/session/onnxruntime_cxx_api.h" +#include struct Input { const char* name = nullptr; @@ -81,7 +82,6 @@ struct MyCustomOpMultipleDynamicInputs : Ort::CustomOpBase(info, "int_attr"); + float_attr_ = ort_.KernelInfoGetAttribute(info, "float_attr"); + + ints_attr_ = ort_.KernelInfoGetAttribute>(info, "ints_attr"); + floats_attr_ = ort_.KernelInfoGetAttribute>(info, "floats_attr"); + + string_arr_ = ort_.KernelInfoGetAttribute(info, "string_attr"); + } + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + + int64_t int_attr_; + float float_attr_; + + std::vector ints_attr_; + std::vector floats_attr_; + + std::string string_arr_; +}; + +struct MyCustomOpWithAttributes : Ort::CustomOpBase { + explicit MyCustomOpWithAttributes(const char* provider) : provider_(provider) {} + + void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernelWithAttributes(api, info); }; + const char* GetName() const { return "FooBar_Attr"; }; + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + + private: + const char* provider_; +}; diff --git a/onnxruntime/test/testdata/foo_bar_3.onnx b/onnxruntime/test/testdata/foo_bar_3.onnx new file mode 100644 index 0000000000..f645936f43 Binary files /dev/null and b/onnxruntime/test/testdata/foo_bar_3.onnx differ