Add new api function At() (#4457)

* add modern standards to function arguments
* add first version of At for better tensor element access
This commit is contained in:
Josh Bradley 2020-08-11 21:34:03 -04:00 committed by GitHub
parent 38c804a048
commit b7254551f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 106 additions and 31 deletions

View file

@ -22,8 +22,8 @@
namespace Ort::Experimental {
struct Session : Ort::Session {
Session(Env& env, ORTCHAR_T* model_path, SessionOptions& options)
: Ort::Session(env, model_path, options){};
Session(Env& env, std::basic_string<ORTCHAR_T>& model_path, SessionOptions& options)
: Ort::Session(env, model_path.data(), options){};
Session(Env& env, void* model_data, size_t model_data_length, SessionOptions& options)
: Ort::Session(env, model_data, model_data_length, options){};

View file

@ -959,6 +959,20 @@ struct OrtApi {
*/
void(ORT_API_CALL* ClearBoundInputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
void(ORT_API_CALL* ClearBoundOutputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
/**
* Provides element-level access into a tensor.
* \param value the tensor to index into
* \param location_values a pointer to an array of index values that specify an element's location in the tensor data blob
* \param location_values_count length of location_values
* \param out a pointer to the element specified by location_values
* e.g.
* Given a tensor with overall shape [3,224,224], an element at
* location [2,150,128] can be accessed directly.
*
* This function only works for numeric tensors.
* This is a no-copy method whose pointer is only valid until the backing OrtValue is freed.
*/
ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
};
/*

View file

@ -331,6 +331,9 @@ struct Value : Base<OrtValue> {
template<typename T>
const T* GetTensorData() const;
template <typename T>
T At(const std::initializer_list<size_t>& location);
TypeInfo GetTypeInfo() const;
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;

View file

@ -738,6 +738,14 @@ const T* Value::GetTensorData() const {
return out;
}
template <typename T>
inline T Value::At(const std::initializer_list<size_t>& location) {
  // TensorAt takes a mutable size_t*, but initializer_list only exposes
  // const storage, so the indices must be copied into a temporary buffer.
  std::vector<size_t> indices{location};
  T* element = nullptr;
  // The returned pointer aliases the tensor's backing buffer (no copy); it is
  // only valid while this OrtValue is alive.
  ThrowOnError(GetApi().TensorAt(p_, indices.data(), indices.size(), reinterpret_cast<void**>(&element)));
  return *element;
}
inline TypeInfo Value::GetTypeInfo() const {
OrtTypeInfo* output;
ThrowOnError(GetApi().GetTypeInfo(p_, &output));

View file

@ -1630,6 +1630,31 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
return NULL;
}
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count,
                    _Outptr_ void** out) {
  TENSOR_READWRITE_API_BEGIN
  //TODO: test if it's a string tensor
  const auto& shape = tensor->Shape();
  const size_t num_dims = shape.NumDimensions();
  if (location_values_count != num_dims)
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "location dimensions do not match shape size");
  // Data is row-major: validate each index against its dimension and fold it
  // into the flat offset with Horner's rule, offset = ((l0*d1 + l1)*d2 + l2)...
  // This is a single O(ndim) pass and needs no temporary index buffer.
  size_t offset = 0;
  for (size_t i = 0; i < num_dims; i++) {
    const size_t dim = static_cast<size_t>(shape[i]);
    if (location_values[i] >= dim)
      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "invalid location range");
    offset = offset * dim + location_values[i];
  }
  // Return a raw pointer into the tensor's buffer (no copy), scaled by the
  // element size of the tensor's data type.
  char* data = static_cast<char*>(tensor->MutableDataRaw()) + tensor->DataType()->Size() * offset;
  *out = static_cast<void*>(data);
  return nullptr;
  API_IMPL_END
}
// End support for non-tensor types
static constexpr OrtApiBase ort_api_base = {
@ -1848,6 +1873,7 @@ static constexpr OrtApi ort_api_1_to_4 = {
&OrtApis::ClearBoundInputs,
&OrtApis::ClearBoundOutputs,
&OrtApis::TensorAt,
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)

View file

@ -198,7 +198,6 @@ ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions*);
ORT_API_STATUS_IMPL(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys);
ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value);
ORT_API_STATUS_IMPL(CreateAllocator, const OrtSession* sess, const OrtMemoryInfo* mem_info,
@ -229,4 +228,6 @@ ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char **ptr,
ORT_API_STATUS_IMPL(EnablePrePacking, _Inout_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(DisablePrePacking, _Inout_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
} // namespace OrtApis

View file

@ -363,8 +363,7 @@ TEST(CApiTest, test_custom_op_library) {
#if defined(ENABLE_LANGUAGE_INTEROP_OPS)
std::once_flag my_module_flag;
void PrepareModule()
{
void PrepareModule() {
std::ofstream module("mymodule.py");
module << "class MyKernel:" << std::endl;
module << "\t"
@ -503,7 +502,6 @@ TEST(CApiTest, io_binding) {
Ort::Value bound_y = Ort::Value::CreateTensor(info_cpu, y_values.data(), y_values.size(),
y_shape.data(), y_shape.size());
Ort::IoBinding binding(session);
binding.BindInput("X", bound_x);
binding.BindOutput("Y", bound_y);
@ -586,7 +584,6 @@ TEST(CApiTest, fill_string_tensor) {
Ort::Value tensor = Ort::Value::CreateTensor(default_allocator.get(), &expected_len, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
for (int64_t i = 0; i < expected_len; i++) {
tensor.FillStringTensorElement(s[i], i);
}
@ -636,6 +633,31 @@ TEST(CApiTest, create_tensor_with_data) {
ASSERT_EQ(1u, tensor_info.GetDimensionsCount());
}
TEST(CApiTest, access_tensor_data_elements) {
  /**
   * Create a 2x3 data blob that looks like:
   *
   *   0 1 2
   *   3 4 5
   */
  std::vector<int64_t> shape = {2, 3};
  // Derive the element count from the shape instead of hard-coding "6" so the
  // test stays consistent if the shape is ever changed.
  const size_t element_count = static_cast<size_t>(shape[0] * shape[1]);
  std::vector<float> values(element_count);
  for (size_t i = 0; i < element_count; i++)
    values[i] = static_cast<float>(i);

  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
  Ort::Value tensor = Ort::Value::CreateTensor<float>(info, values.data(), values.size(), shape.data(), shape.size());

  // At() must walk the buffer in row-major order: element [row, col] holds
  // row * shape[1] + col.
  float expected_value = 0;
  for (size_t row = 0; row < static_cast<size_t>(shape[0]); row++) {
    for (size_t col = 0; col < static_cast<size_t>(shape[1]); col++) {
      ASSERT_EQ(expected_value++, tensor.At<float>({row, col}));
    }
  }
}
TEST(CApiTest, override_initializer) {
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();

View file

@ -1,10 +1,9 @@
This directory contains a few (Windows only) C/C++ sample applications for demoing onnxruntime usage:
This directory contains a few C/C++ sample applications for demoing onnxruntime usage:
1. fns_candy_style_transfer: A C application that uses the FNS-Candy style transfer model to re-style images.
2. MNIST: A windows GUI application for doing handwriting recognition
3. imagenet: An end-to-end sample for the [ImageNet Large Scale Visual Recognition Challenge 2012](http://www.image-net.org/challenges/LSVRC/2012/)
Imagenet sample requires ATL libraries installed as a part of VS Studio installation.
1. (Windows only) fns_candy_style_transfer: A C application that uses the FNS-Candy style transfer model to re-style images.
2. (Windows only) MNIST: A Windows GUI application for doing handwriting recognition
3. (Windows only) imagenet: An end-to-end sample for the [ImageNet Large Scale Visual Recognition Challenge 2012](http://www.image-net.org/challenges/LSVRC/2012/) - requires ATL libraries to be installed as a part of the VS Studio installation.
4. model-explorer: A commandline C++ application that generates random data and performs model inference. A second C++ application demonstrates how to perform batch processing.
# How to build

View file

@ -60,10 +60,12 @@ int main(int argc, char** argv) {
cout << "Usage: ./onnx-api-example <onnx_model.onnx>" << endl;
return -1;
}
std::string model_file = argv[1];
// onnxruntime setup
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "batch-model-explorer");
Ort::SessionOptions session_options;
Ort::Experimental::Session session = Ort::Experimental::Session(env, argv[1], session_options);
Ort::Experimental::Session session = Ort::Experimental::Session(env, model_file, session_options);
// print name/shape of inputs
auto input_names = session.GetInputNames();
@ -98,9 +100,8 @@ int main(int argc, char** argv) {
// Create an Ort tensor containing random numbers
std::vector<float> batch_input_tensor_values(num_elements_per_batch);
std::generate(batch_input_tensor_values.begin(), batch_input_tensor_values.end(), [&] { return rand() % 255; }); // generate random numbers in the range [0, 255]
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> batch_input_tensors;
batch_input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, batch_input_tensor_values.data(), batch_input_tensor_values.size(), input_shape.data(), input_shape.size()));
batch_input_tensors.push_back(Ort::Experimental::Value::CreateTensor<float>(batch_input_tensor_values.data(), batch_input_tensor_values.size(), input_shape));
// double-check the dimensions of the input tensor
assert(batch_input_tensors[0].IsTensor() &&

View file

@ -49,10 +49,12 @@ int main(int argc, char** argv) {
cout << "Usage: ./onnx-api-example <onnx_model.onnx>" << endl;
return -1;
}
std::string model_file = argv[1];
// onnxruntime setup
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "example-model-explorer");
Ort::SessionOptions session_options;
Ort::Experimental::Session session = Ort::Experimental::Session(env, argv[1], session_options); // access experimental components via the Experimental namespace
Ort::Experimental::Session session = Ort::Experimental::Session(env, model_file, session_options); // access experimental components via the Experimental namespace
// print name/shape of inputs
std::vector<std::string> input_names = session.GetInputNames();
@ -78,9 +80,8 @@ int main(int argc, char** argv) {
int total_number_elements = calculate_product(input_shape);
std::vector<float> input_tensor_values(total_number_elements);
std::generate(input_tensor_values.begin(), input_tensor_values.end(), [&] { return rand() % 255; }); // generate random numbers in the range [0, 255]
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> input_tensors;
input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_shape.data(), input_shape.size()));
input_tensors.push_back(Ort::Experimental::Value::CreateTensor<float>(input_tensor_values.data(), input_tensor_values.size(), input_shape));
// double-check the dimensions of the input tensor
assert(input_tensors[0].IsTensor() &&