Support parsing an array of values stored as an attribute in a custom op (#6878)

This commit is contained in:
Hariharan Seshadri 2021-03-08 23:49:58 -08:00 committed by GitHub
parent e64eff1f13
commit c8e2e3191b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 272 additions and 8 deletions

View file

@ -728,10 +728,40 @@ struct OrtApi {
ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in,
_Out_ void* data_container, size_t data_container_size);
/**
* Fetch a float stored as an attribute in the graph node
* \info - OrtKernelInfo instance
* \name - name of the attribute to be parsed
* \out - pointer to memory where the attribute is to be stored
*/
ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Out_ float* out);
/**
* Fetch a 64-bit int stored as an attribute in the graph node
* \info - OrtKernelInfo instance
* \name - name of the attribute to be parsed
* \out - pointer to memory where the attribute is to be stored
*/
ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Out_ int64_t* out);
/**
* Fetch a string stored as an attribute in the graph node
* \info - OrtKernelInfo instance
* \name - name of the attribute to be parsed
* \out - pointer to memory where the attribute's contents are to be stored
* \size - actual size of string attribute
* (If `out` is nullptr, the value of `size` is set to the true size of the string
attribute, and a success status is returned.
If the `size` parameter is greater than or equal to the actual string attribute's size,
the value of `size` is set to the true size of the string attribute, the provided memory
is filled with the attribute's contents, and a success status is returned.
If the `size` parameter is lesser than the actual string attribute's size and `out`
is not nullptr, the value of `size` is set to the true size of the string attribute
and a failure status is returned.)
*/
ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out,
_Inout_ size_t* size);
@ -1179,6 +1209,48 @@ struct OrtApi {
* Get the current device id of the GPU execution provider (cuda/tensorrt/rocm).
*/
ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id);
/**
* Fetch an array of float values stored as an attribute in the graph node
* \info - OrtKernelInfo instance
* \name - name of the attribute to be parsed
* \out - pointer to memory where the attribute's contents are to be stored
* \size - actual size of attribute array
* (If `out` is nullptr, the value of `size` is set to the true size of the attribute
array's size, and a success status is returned.
If the `size` parameter is greater than or equal to the actual attribute array's size,
the value of `size` is set to the true size of the attribute array's size,
the provided memory is filled with the attribute's contents,
and a success status is returned.
If the `size` parameter is lesser than the actual attribute array's size and `out`
is not nullptr, the value of `size` is set to the true size of the attribute array's size
and a failure status is returned.)
*/
ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Out_ float* out, _Inout_ size_t* size);
/**
* Fetch an array of int64_t values stored as an attribute in the graph node
* \info - OrtKernelInfo instance
* \name - name of the attribute to be parsed
* \out - pointer to memory where the attribute's contents are to be stored
* \size - actual size of attribute array
* (If `out` is nullptr, the value of `size` is set to the true size of the attribute
array's size, and a success status is returned.
If the `size` parameter is greater than or equal to the actual attribute array's size,
the value of `size` is set to the true size of the attribute array's size,
the provided memory is filled with the attribute's contents,
and a success status is returned.
If the `size` parameter is lesser than the actual attribute array's size and `out`
is not nullptr, the value of `size` is set to the true size of the attribute array's size
and a failure status is returned.)
*/
ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Out_ int64_t* out, _Inout_ size_t* size);
};
/*

View file

@ -578,7 +578,7 @@ struct ArenaCfg : Base<OrtArenaCfg> {
struct CustomOpApi {
CustomOpApi(const OrtApi& api) : api_(api) {}
template <typename T> // T is only implemented for float, int64_t, and string
template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);

View file

@ -877,11 +877,11 @@ template <>
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
size_t size = 0;
std::string out;
// Feed nullptr for the data buffer to query the true size of the string attribute
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
// The status should be ORT_INVALID_ARGUMENT because the size is insufficient to hold the string
if (api_.GetErrorCode(status) == ORT_INVALID_ARGUMENT) {
api_.ReleaseStatus(status);
if (status == nullptr) {
out.resize(size);
ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
out.resize(size - 1); // remove the terminating character '\0'
@ -891,6 +891,39 @@ inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const O
return out;
}
template <>
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
  // First pass: a null data buffer makes the C API report the element count only.
  size_t element_count = 0;
  OrtStatus* query_status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &element_count);
  if (query_status != nullptr) {
    // Attribute missing or of the wrong type - surface the error as an exception.
    ThrowOnError(query_status);
  }
  // Second pass: fill a correctly-sized buffer with the attribute's contents.
  std::vector<float> result(element_count);
  ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, result.data(), &element_count));
  return result;
}
template <>
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
  // First pass: a null data buffer makes the C API report the element count only.
  size_t element_count = 0;
  OrtStatus* query_status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &element_count);
  if (query_status != nullptr) {
    // Attribute missing or of the wrong type - surface the error as an exception.
    ThrowOnError(query_status);
  }
  // Second pass: fill a correctly-sized buffer with the attribute's contents.
  std::vector<int64_t> result(element_count);
  ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, result.data(), &element_count));
  return result;
}
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
OrtTensorTypeAndShapeInfo* out;
ThrowOnError(api_.GetTensorTypeAndShape(value, &out));

View file

@ -13,6 +13,7 @@
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"
#include "core/session/ort_apis.h"
#include <type_traits>
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) {
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<float>(name, out);
@ -53,12 +54,15 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel
std::string value;
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttr<std::string>(name, &value);
if (status.IsOK()) {
if (*size >= value.size() + 1) {
if (out == nullptr) { // User is querying the true size of the attribute
*size = value.size() + 1;
return nullptr;
} else if (*size >= value.size() + 1) {
std::memcpy(out, value.data(), value.size());
out[value.size()] = '\0';
*size = value.size() + 1;
return nullptr;
} else {
} else { // User has provided a buffer that is not large enough
*size = value.size() + 1;
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough");
}
@ -66,6 +70,42 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernel
return onnxruntime::ToOrtStatus(status);
}
// Copies `values` into the caller-provided buffer `out`, whose capacity (in
// elements) is passed in via `*size`. In every case `*size` is updated to the
// true element count before returning. A null `out` is treated as a pure size
// query and succeeds; an undersized buffer yields INVALID_ARGUMENT.
// Restricted to fundamental T so the raw memcpy is well-defined.
template <typename T, typename std::enable_if<std::is_fundamental<T>::value, int>::type = 0>
static Status CopyDataFromVectorToMemory(const std::vector<T>& values, T* out, size_t* size) {
  if (out == nullptr) {  // User is querying the true size of the attribute
    *size = values.size();
    return Status::OK();
  }
  const size_t capacity = *size;
  *size = values.size();  // always report the true size back to the caller
  if (capacity < values.size()) {  // User has provided a buffer that is not large enough
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Result buffer is not large enough");
  }
  std::memcpy(out, values.data(), values.size() * sizeof(T));
  return Status::OK();
}
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
                    _Out_ float* out, _Inout_ size_t* size) {
  // Fetch the repeated-float attribute from the kernel's node, then hand its
  // contents (or just the element count, when `out` is null) to the caller.
  std::vector<float> attr_values;
  auto fetch_status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttrs<float>(name, attr_values);
  if (!fetch_status.IsOK()) {
    return onnxruntime::ToOrtStatus(fetch_status);
  }
  return onnxruntime::ToOrtStatus(CopyDataFromVectorToMemory<float>(attr_values, out, size));
}
ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
                    _Out_ int64_t* out, _Inout_ size_t* size) {
  // Fetch the repeated-int64 attribute from the kernel's node, then hand its
  // contents (or just the element count, when `out` is null) to the caller.
  std::vector<int64_t> attr_values;
  auto fetch_status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttrs<int64_t>(name, attr_values);
  if (!fetch_status.IsOK()) {
    return onnxruntime::ToOrtStatus(fetch_status);
  }
  return onnxruntime::ToOrtStatus(CopyDataFromVectorToMemory<int64_t>(attr_values, out, size));
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
#include "core/framework/customregistry.h"
namespace onnxruntime {

View file

@ -2111,6 +2111,8 @@ static constexpr OrtApi ort_api_1_to_8 = {
// End of Version 7 - DO NOT MODIFY ABOVE (see above text for more information)
// Version 8 - In development, feel free to add/remove/rearrange here
&OrtApis::KernelInfoGetAttributeArray_float,
&OrtApis::KernelInfoGetAttributeArray_int64,
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)

View file

@ -259,4 +259,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_TensorRT,
_In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options);
ORT_API_STATUS_IMPL(SetCurrentGpuDeviceId, _In_ int device_id);
ORT_API_STATUS_IMPL(GetCurrentGpuDeviceId, _In_ int* device_id);
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size);
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size);
} // namespace OrtApis

View file

@ -167,6 +167,7 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/Var
static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
@ -529,6 +530,49 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
}
}
}
// Exercises a custom op that parses scalar (float/int64/string) and array
// (float/int64) attributes at kernel-creation time, then verifies the kernel
// folded all of them into the expected output value (15.f for input 1.f).
TEST(CApiTest, custom_op_with_attributes_handler) {
  MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider};

  Ort::CustomOpDomain custom_op_domain("");
  custom_op_domain.Add(&custom_op);

  Ort::SessionOptions session_options;
  session_options.Add(custom_op_domain);

  Ort::Session session(*ort_env, CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI, session_options);

  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> input_names;

  // input 0 (float type)
  input_names.emplace_back("X");
  std::vector<float> input_0_data = {1.f};
  std::vector<int64_t> input_0_dims = {1};
  // data() on a non-const vector already yields float*; no const_cast needed.
  ort_inputs.emplace_back(
      Ort::Value::CreateTensor<float>(info, input_0_data.data(),
                                      input_0_data.size(), input_0_dims.data(), input_0_dims.size()));

  // Run
  const char* output_name = "Y";
  auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                                 &output_name, 1);
  ASSERT_EQ(ort_outputs.size(), 1u);

  // Validate results
  std::vector<int64_t> y_dims = {1};
  std::vector<float> values_y = {15.f};
  auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
  ASSERT_EQ(type_info.GetShape(), y_dims);
  size_t total_len = type_info.GetElementCount();
  ASSERT_EQ(values_y.size(), total_len);

  float* f = ort_outputs[0].GetTensorMutableData<float>();
  for (size_t i = 0; i != total_len; ++i) {
    ASSERT_EQ(values_y[i], f[i]);
  }
}
// Tests registration of a custom op of the same name for both CPU and CUDA EPs
#ifdef USE_CUDA

View file

@ -87,3 +87,32 @@ void MyCustomKernelWithOptionalInput::Compute(OrtKernelContext* context) {
out[i] = X1[i] + (X2 != nullptr ? X2[i] : 0) + X3[i];
}
}
void MyCustomKernelWithAttributes::Compute(OrtKernelContext* context) {
  // Fetch the single float input tensor.
  const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
  const float* X = ort_.GetTensorData<float>(input_X);

  // Allocate an output tensor with the same shape as the input.
  OrtTensorDimensions dimensions(ort_, input_X);
  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
  float* out = ort_.GetTensorMutableData<float>(output);

  OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
  const int64_t element_count = ort_.GetTensorShapeElementCount(output_info);
  ort_.ReleaseTensorTypeAndShapeInfo(output_info);

  // This kernel only supports CPU EP
  if (string_arr_ != "add") {
    // The string attribute was parsed incorrectly - write zeros so the test's
    // result comparison fails and flags the problem.
    for (int64_t i = 0; i < element_count; ++i) {
      out[i] = 0.f;
    }
    return;
  }

  // String attribute parsing went correctly: fold every parsed attribute
  // (scalar and array, float and int64) into each output element.
  for (int64_t i = 0; i < element_count; ++i) {
    out[i] = X[i] +
             float_attr_ + static_cast<float>(int_attr_) +
             floats_attr_[0] + floats_attr_[1] +
             static_cast<float>(ints_attr_[0]) + static_cast<float>(ints_attr_[1]);
  }
}

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/session/onnxruntime_cxx_api.h"
#include <vector>
struct Input {
const char* name = nullptr;
@ -81,7 +82,6 @@ struct MyCustomOpMultipleDynamicInputs : Ort::CustomOpBase<MyCustomOpMultipleDyn
struct MyCustomKernelWithOptionalInput {
MyCustomKernelWithOptionalInput(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
}
void Compute(OrtKernelContext* context);
private:
@ -113,4 +113,46 @@ struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInp
private:
const char* provider_;
};
};
// Custom kernel that exercises every KernelInfoGetAttribute overload:
// scalar int64/float, int64/float arrays, and string.
struct MyCustomKernelWithAttributes {
MyCustomKernelWithAttributes(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {
// Parse each attribute kind off the node at kernel-construction time;
// KernelInfoGetAttribute throws if an attribute is missing or mistyped.
int_attr_ = ort_.KernelInfoGetAttribute<int64_t>(info, "int_attr");
float_attr_ = ort_.KernelInfoGetAttribute<float>(info, "float_attr");
// Array attributes route through the KernelInfoGetAttributeArray_* C APIs.
ints_attr_ = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "ints_attr");
floats_attr_ = ort_.KernelInfoGetAttribute<std::vector<float>>(info, "floats_attr");
string_arr_ = ort_.KernelInfoGetAttribute<std::string>(info, "string_attr");
}
void Compute(OrtKernelContext* context);
private:
Ort::CustomOpApi ort_;
// Parsed attribute values consumed by Compute().
int64_t int_attr_;
float float_attr_;
std::vector<int64_t> ints_attr_;
std::vector<float> floats_attr_;
std::string string_arr_;
};
// Custom op schema for "FooBar_Attr": one float tensor in, one float tensor
// out; the attribute parsing happens in MyCustomKernelWithAttributes.
struct MyCustomOpWithAttributes : Ort::CustomOpBase<MyCustomOpWithAttributes, MyCustomKernelWithAttributes> {
  explicit MyCustomOpWithAttributes(const char* provider) : provider_(provider) {}

  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
    return new MyCustomKernelWithAttributes(api, info);
  }
  const char* GetName() const { return "FooBar_Attr"; }
  const char* GetExecutionProviderType() const { return provider_; }

  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }

  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }

 private:
  const char* provider_;
};

BIN
onnxruntime/test/testdata/foo_bar_3.onnx vendored Normal file

Binary file not shown.