mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add new api function At() (#4457)
* add modern standards to function arguments * add first version of At for better tensor element access
This commit is contained in:
parent
38c804a048
commit
b7254551f0
10 changed files with 106 additions and 31 deletions
|
|
@ -22,8 +22,8 @@
|
|||
namespace Ort::Experimental {
|
||||
|
||||
struct Session : Ort::Session {
|
||||
Session(Env& env, ORTCHAR_T* model_path, SessionOptions& options)
|
||||
: Ort::Session(env, model_path, options){};
|
||||
Session(Env& env, std::basic_string<ORTCHAR_T>& model_path, SessionOptions& options)
|
||||
: Ort::Session(env, model_path.data(), options){};
|
||||
Session(Env& env, void* model_data, size_t model_data_length, SessionOptions& options)
|
||||
: Ort::Session(env, model_data, model_data_length, options){};
|
||||
|
||||
|
|
|
|||
|
|
@ -959,6 +959,20 @@ struct OrtApi {
|
|||
*/
|
||||
void(ORT_API_CALL* ClearBoundInputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
|
||||
void(ORT_API_CALL* ClearBoundOutputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
|
||||
|
||||
/**
|
||||
* Provides element-level access into a tensor.
|
||||
* \param location_values a pointer to an array of index values that specify an element's location in the tensor data blob
|
||||
* \param location_values_count length of location_values
|
||||
* \param out a pointer to the element specified by location_values
|
||||
* e.g.
|
||||
* Given a tensor with overall shape [3,224,224], an element at
|
||||
* location [2,150,128] can be accessed directly.
|
||||
*
|
||||
* This function only works for numeric tensors.
|
||||
* This is a no-copy method whose pointer is only valid until the backing OrtValue is free'd.
|
||||
*/
|
||||
ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -331,6 +331,9 @@ struct Value : Base<OrtValue> {
|
|||
template<typename T>
|
||||
const T* GetTensorData() const;
|
||||
|
||||
template <typename T>
|
||||
T At(const std::initializer_list<size_t>& location);
|
||||
|
||||
TypeInfo GetTypeInfo() const;
|
||||
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -738,6 +738,14 @@ const T* Value::GetTensorData() const {
|
|||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T Value::At(const std::initializer_list<size_t>& location) {
|
||||
T* out;
|
||||
std::vector<size_t> location_ = location;
|
||||
ThrowOnError(GetApi().TensorAt(p_, location_.data(), location_.size(), (void**)&out));
|
||||
return *out;
|
||||
}
|
||||
|
||||
inline TypeInfo Value::GetTypeInfo() const {
|
||||
OrtTypeInfo* output;
|
||||
ThrowOnError(GetApi().GetTypeInfo(p_, &output));
|
||||
|
|
|
|||
|
|
@ -1630,6 +1630,31 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
|
|||
return NULL;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count,
|
||||
_Outptr_ void** out) {
|
||||
TENSOR_READWRITE_API_BEGIN
|
||||
//TODO: test if it's a string tensor
|
||||
if (location_values_count != tensor->Shape().NumDimensions())
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "location dimensions do not match shape size");
|
||||
std::vector<size_t> location(location_values_count);
|
||||
for (size_t i = 0; i < location_values_count; i++) {
|
||||
if (location_values[i] >= (size_t)tensor->Shape()[i])
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "invalid location range");
|
||||
location[i] = location_values[i];
|
||||
}
|
||||
// data has row-major format
|
||||
size_t offset = 0;
|
||||
for (size_t i = 1; i <= tensor->Shape().NumDimensions(); i++) {
|
||||
size_t sum = 1;
|
||||
for (size_t j = i+1; j <= tensor->Shape().NumDimensions(); j++) sum *= (size_t)tensor->Shape()[j-1];
|
||||
offset += location[i-1] * sum;
|
||||
}
|
||||
auto data = ((char *)tensor->MutableDataRaw()) + (tensor->DataType()->Size() * offset);
|
||||
*out = (void *)data;
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
// End support for non-tensor types
|
||||
|
||||
static constexpr OrtApiBase ort_api_base = {
|
||||
|
|
@ -1848,6 +1873,7 @@ static constexpr OrtApi ort_api_1_to_4 = {
|
|||
&OrtApis::ClearBoundInputs,
|
||||
&OrtApis::ClearBoundOutputs,
|
||||
|
||||
&OrtApis::TensorAt,
|
||||
};
|
||||
|
||||
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
|
||||
|
|
|
|||
|
|
@ -198,7 +198,6 @@ ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions*);
|
|||
ORT_API_STATUS_IMPL(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata,
|
||||
_Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys);
|
||||
|
||||
|
||||
ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value);
|
||||
|
||||
ORT_API_STATUS_IMPL(CreateAllocator, const OrtSession* sess, const OrtMemoryInfo* mem_info,
|
||||
|
|
@ -229,4 +228,6 @@ ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char **ptr,
|
|||
ORT_API_STATUS_IMPL(EnablePrePacking, _Inout_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS_IMPL(DisablePrePacking, _Inout_ OrtSessionOptions* options);
|
||||
|
||||
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
|
||||
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -363,8 +363,7 @@ TEST(CApiTest, test_custom_op_library) {
|
|||
#if defined(ENABLE_LANGUAGE_INTEROP_OPS)
|
||||
std::once_flag my_module_flag;
|
||||
|
||||
void PrepareModule()
|
||||
{
|
||||
void PrepareModule() {
|
||||
std::ofstream module("mymodule.py");
|
||||
module << "class MyKernel:" << std::endl;
|
||||
module << "\t"
|
||||
|
|
@ -503,7 +502,6 @@ TEST(CApiTest, io_binding) {
|
|||
Ort::Value bound_y = Ort::Value::CreateTensor(info_cpu, y_values.data(), y_values.size(),
|
||||
y_shape.data(), y_shape.size());
|
||||
|
||||
|
||||
Ort::IoBinding binding(session);
|
||||
binding.BindInput("X", bound_x);
|
||||
binding.BindOutput("Y", bound_y);
|
||||
|
|
@ -586,7 +584,6 @@ TEST(CApiTest, fill_string_tensor) {
|
|||
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
|
||||
for (int64_t i = 0; i < expected_len; i++) {
|
||||
|
||||
tensor.FillStringTensorElement(s[i], i);
|
||||
}
|
||||
|
||||
|
|
@ -636,6 +633,31 @@ TEST(CApiTest, create_tensor_with_data) {
|
|||
ASSERT_EQ(1u, tensor_info.GetDimensionsCount());
|
||||
}
|
||||
|
||||
TEST(CApiTest, access_tensor_data_elements) {
|
||||
/**
|
||||
* Create a 2x3 data blob that looks like:
|
||||
*
|
||||
* 0 1 2
|
||||
* 3 4 5
|
||||
*/
|
||||
std::vector<int64_t> shape = {2, 3};
|
||||
int element_count = 6; // 2*3
|
||||
std::vector<float> values(element_count);
|
||||
for (int i = 0; i < element_count; i++)
|
||||
values[i] = static_cast<float>(i);
|
||||
|
||||
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value tensor = Ort::Value::CreateTensor<float>(info, values.data(), values.size(), shape.data(), shape.size());
|
||||
|
||||
float expected_value = 0;
|
||||
for (size_t row = 0; row < (size_t)shape[0]; row++) {
|
||||
for (size_t col = 0; col < (size_t)shape[1]; col++) {
|
||||
ASSERT_EQ(expected_value++, tensor.At<float>({row, col}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CApiTest, override_initializer) {
|
||||
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||
auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
This directory contains a few (Windows only) C/C++ sample applications for demoing onnxruntime usage:
|
||||
This directory contains a few C/C++ sample applications for demoing onnxruntime usage:
|
||||
|
||||
1. fns_candy_style_transfer: A C application that uses the FNS-Candy style transfer model to re-style images.
|
||||
2. MNIST: A windows GUI application for doing handwriting recognition
|
||||
3. imagenet: An end-to-end sample for the [ImageNet Large Scale Visual Recognition Challenge 2012](http://www.image-net.org/challenges/LSVRC/2012/)
|
||||
|
||||
Imagenet sample requires ATL libraries installed as a part of VS Studio installation.
|
||||
1. (Windows only) fns_candy_style_transfer: A C application that uses the FNS-Candy style transfer model to re-style images.
|
||||
2. (Windows only) MNIST: A windows GUI application for doing handwriting recognition
|
||||
3. (Windows only) imagenet: An end-to-end sample for the [ImageNet Large Scale Visual Recognition Challenge 2012](http://www.image-net.org/challenges/LSVRC/2012/) - requires ATL libraries to be installed as a part of the VS Studio installation.
|
||||
4. model-explorer: A commandline C++ application that generates random data and performs model inference. A second C++ application demonstrates how to perform batch processing.
|
||||
|
||||
# How to build
|
||||
|
||||
|
|
|
|||
|
|
@ -60,10 +60,12 @@ int main(int argc, char** argv) {
|
|||
cout << "Usage: ./onnx-api-example <onnx_model.onnx>" << endl;
|
||||
return -1;
|
||||
}
|
||||
std::string model_file = argv[1];
|
||||
|
||||
// onnxruntime setup
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "batch-model-explorer");
|
||||
Ort::SessionOptions session_options;
|
||||
Ort::Experimental::Session session = Ort::Experimental::Session(env, argv[1], session_options);
|
||||
Ort::Experimental::Session session = Ort::Experimental::Session(env, model_file, session_options);
|
||||
|
||||
// print name/shape of inputs
|
||||
auto input_names = session.GetInputNames();
|
||||
|
|
@ -98,9 +100,8 @@ int main(int argc, char** argv) {
|
|||
// Create an Ort tensor containing random numbers
|
||||
std::vector<float> batch_input_tensor_values(num_elements_per_batch);
|
||||
std::generate(batch_input_tensor_values.begin(), batch_input_tensor_values.end(), [&] { return rand() % 255; }); // generate random numbers in the range [0, 255]
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
std::vector<Ort::Value> batch_input_tensors;
|
||||
batch_input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, batch_input_tensor_values.data(), batch_input_tensor_values.size(), input_shape.data(), input_shape.size()));
|
||||
batch_input_tensors.push_back(Ort::Experimental::Value::CreateTensor<float>(batch_input_tensor_values.data(), batch_input_tensor_values.size(), input_shape));
|
||||
|
||||
// double-check the dimensions of the input tensor
|
||||
assert(batch_input_tensors[0].IsTensor() &&
|
||||
|
|
|
|||
|
|
@ -49,10 +49,12 @@ int main(int argc, char** argv) {
|
|||
cout << "Usage: ./onnx-api-example <onnx_model.onnx>" << endl;
|
||||
return -1;
|
||||
}
|
||||
std::string model_file = argv[1];
|
||||
|
||||
// onnxruntime setup
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "example-model-explorer");
|
||||
Ort::SessionOptions session_options;
|
||||
Ort::Experimental::Session session = Ort::Experimental::Session(env, argv[1], session_options); // access experimental components via the Experimental namespace
|
||||
Ort::Experimental::Session session = Ort::Experimental::Session(env, model_file, session_options); // access experimental components via the Experimental namespace
|
||||
|
||||
// print name/shape of inputs
|
||||
std::vector<std::string> input_names = session.GetInputNames();
|
||||
|
|
@ -78,9 +80,8 @@ int main(int argc, char** argv) {
|
|||
int total_number_elements = calculate_product(input_shape);
|
||||
std::vector<float> input_tensor_values(total_number_elements);
|
||||
std::generate(input_tensor_values.begin(), input_tensor_values.end(), [&] { return rand() % 255; }); // generate random numbers in the range [0, 255]
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
std::vector<Ort::Value> input_tensors;
|
||||
input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_shape.data(), input_shape.size()));
|
||||
input_tensors.push_back(Ort::Experimental::Value::CreateTensor<float>(input_tensor_values.data(), input_tensor_values.size(), input_shape));
|
||||
|
||||
// double-check the dimensions of the input tensor
|
||||
assert(input_tensors[0].IsTensor() &&
|
||||
|
|
|
|||
Loading…
Reference in a new issue