diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cbf9fe232d..6dab20a055 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -1009,7 +1009,7 @@ struct OrtApi { * This function only works for numeric tensors. * This is a no-copy method whose pointer is only valid until the backing OrtValue is free'd. */ - ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out); + ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); /** * Creates an allocator instance and registers it with the env to enable diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index ff2ca58696..9e8f365adc 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -371,7 +371,7 @@ struct Value : Base { const T* GetTensorData() const; template - T At(const std::initializer_list& location); + T& At(const std::vector& location); TypeInfo GetTypeInfo() const; TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index ffe927d457..104ae566d5 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -777,10 +777,10 @@ const T* Value::GetTensorData() const { } template -inline T Value::At(const std::initializer_list& location) { +inline T& Value::At(const std::vector& location) { + static_assert(!std::is_same::value, "this api does not support std::string"); T* out; - std::vector location_ = location; - ThrowOnError(GetApi().TensorAt(p_, location_.data(), location_.size(), (void**)&out)); + ThrowOnError(GetApi().TensorAt(p_, location.data(), location.size(), (void**)&out)); return *out; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 424873ac5e..dbec53ec89 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1727,27 +1727,45 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr, return NULL; } -ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, +ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out) { TENSOR_READWRITE_API_BEGIN - //TODO: test if it's a string tensor - if (location_values_count != tensor->Shape().NumDimensions()) + + if(tensor->IsDataTypeString()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "this API does not support strings"); + } + + const auto& tensor_shape = tensor->Shape(); + const auto num_dimensions = tensor_shape.NumDimensions(); + if (location_values_count != num_dimensions) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "location dimensions do not match shape size"); - std::vector location(location_values_count); + } + for (size_t i = 0; i < location_values_count; i++) { - if (location_values[i] >= (size_t)tensor->Shape()[i]) + if (location_values[i] >= tensor_shape[i] || location_values[i] < 0) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "invalid location range"); - location[i] = location_values[i]; + } } - // data has row-major format - size_t offset = 0; - for (size_t i = 1; i <= tensor->Shape().NumDimensions(); i++) { - size_t sum = 1; - for (size_t j = i + 1; j <= tensor->Shape().NumDimensions(); j++) sum *= (size_t)tensor->Shape()[j - 1]; - offset += location[i - 1] * sum; + + // compute strides + // TensorPitches p; + std::vector strides(num_dimensions); + { + int64_t stride = 1; + for (size_t dim = num_dimensions; dim > 0; --dim) { + strides[dim - 1] = stride; + stride *= tensor_shape[dim - 1]; + } } - auto data = ((char*)tensor->MutableDataRaw()) + (tensor->DataType()->Size() * offset); - *out = (void*)data; + + // For Scalers the offset would always be zero + int64_t offset = 0; + for (size_t i = 0; i < num_dimensions; i++) { + offset += location_values[i] * strides[i]; + } + + auto data = reinterpret_cast(tensor->MutableDataRaw()) + tensor->DataType()->Size() * offset; + *out = data; return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 21b4aca34d..2d86ae39d4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -228,7 +228,7 @@ ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char** ptr, ORT_API_STATUS_IMPL(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options, _In_z_ const char* config_key, _In_z_ const char* config_value); -ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out); +ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 2cae2269d3..04169d129c 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -680,8 +680,8 @@ TEST(CApiTest, access_tensor_data_elements) { Ort::Value tensor = Ort::Value::CreateTensor(info, values.data(), values.size(), shape.data(), shape.size()); float expected_value = 0; - for (size_t row = 0; row < (size_t)shape[0]; row++) { - for (size_t col = 0; col < (size_t)shape[1]; col++) { + for (int64_t row = 0; row < shape[0]; row++) { + for (int64_t col = 0; col < shape[1]; col++) { ASSERT_EQ(expected_value++, tensor.At({row, col})); } }