mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Refactor TensorAt, prepare for release (#5180)
* Refactor TensorAt locations* must be const and int64_t since our dims are int64_t Remove unnecessary copy of locations. Remove unnecesary casting and C-casting. Simplify implementation. Add a check for string type. Make CXX api return T& to fully expose C API in C++, const std::vector& by value as it covers more ground and eliminate redundant copy. Eliminate inner loop, compute strides first.
This commit is contained in:
parent
a20f8037f6
commit
e6f85f338e
6 changed files with 40 additions and 22 deletions
|
|
@ -1009,7 +1009,7 @@ struct OrtApi {
|
|||
* This function only works for numeric tensors.
|
||||
* This is a no-copy method whose pointer is only valid until the backing OrtValue is free'd.
|
||||
*/
|
||||
ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
|
||||
ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out);
|
||||
|
||||
/**
|
||||
* Creates an allocator instance and registers it with the env to enable
|
||||
|
|
|
|||
|
|
@ -371,7 +371,7 @@ struct Value : Base<OrtValue> {
|
|||
const T* GetTensorData() const;
|
||||
|
||||
template <typename T>
|
||||
T At(const std::initializer_list<size_t>& location);
|
||||
T& At(const std::vector<int64_t>& location);
|
||||
|
||||
TypeInfo GetTypeInfo() const;
|
||||
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
|
||||
|
|
|
|||
|
|
@ -777,10 +777,10 @@ const T* Value::GetTensorData() const {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
inline T Value::At(const std::initializer_list<size_t>& location) {
|
||||
inline T& Value::At(const std::vector<int64_t>& location) {
|
||||
static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
|
||||
T* out;
|
||||
std::vector<size_t> location_ = location;
|
||||
ThrowOnError(GetApi().TensorAt(p_, location_.data(), location_.size(), (void**)&out));
|
||||
ThrowOnError(GetApi().TensorAt(p_, location.data(), location.size(), (void**)&out));
|
||||
return *out;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1727,27 +1727,45 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
|
|||
return NULL;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count,
|
||||
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count,
|
||||
_Outptr_ void** out) {
|
||||
TENSOR_READWRITE_API_BEGIN
|
||||
//TODO: test if it's a string tensor
|
||||
if (location_values_count != tensor->Shape().NumDimensions())
|
||||
|
||||
if(tensor->IsDataTypeString()) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "this API does not support strings");
|
||||
}
|
||||
|
||||
const auto& tensor_shape = tensor->Shape();
|
||||
const auto num_dimensions = tensor_shape.NumDimensions();
|
||||
if (location_values_count != num_dimensions) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "location dimensions do not match shape size");
|
||||
std::vector<size_t> location(location_values_count);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < location_values_count; i++) {
|
||||
if (location_values[i] >= (size_t)tensor->Shape()[i])
|
||||
if (location_values[i] >= tensor_shape[i] || location_values[i] < 0) {
|
||||
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "invalid location range");
|
||||
location[i] = location_values[i];
|
||||
}
|
||||
}
|
||||
// data has row-major format
|
||||
size_t offset = 0;
|
||||
for (size_t i = 1; i <= tensor->Shape().NumDimensions(); i++) {
|
||||
size_t sum = 1;
|
||||
for (size_t j = i + 1; j <= tensor->Shape().NumDimensions(); j++) sum *= (size_t)tensor->Shape()[j - 1];
|
||||
offset += location[i - 1] * sum;
|
||||
|
||||
// compute strides
|
||||
// TensorPitches p;
|
||||
std::vector<int64_t> strides(num_dimensions);
|
||||
{
|
||||
int64_t stride = 1;
|
||||
for (size_t dim = num_dimensions; dim > 0; --dim) {
|
||||
strides[dim - 1] = stride;
|
||||
stride *= tensor_shape[dim - 1];
|
||||
}
|
||||
}
|
||||
auto data = ((char*)tensor->MutableDataRaw()) + (tensor->DataType()->Size() * offset);
|
||||
*out = (void*)data;
|
||||
|
||||
// For Scalers the offset would always be zero
|
||||
int64_t offset = 0;
|
||||
for (size_t i = 0; i < num_dimensions; i++) {
|
||||
offset += location_values[i] * strides[i];
|
||||
}
|
||||
|
||||
auto data = reinterpret_cast<char*>(tensor->MutableDataRaw()) + tensor->DataType()->Size() * offset;
|
||||
*out = data;
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
|
|
|||
|
|
@ -228,7 +228,7 @@ ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char** ptr,
|
|||
ORT_API_STATUS_IMPL(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options,
|
||||
_In_z_ const char* config_key, _In_z_ const char* config_value);
|
||||
|
||||
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
|
||||
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out);
|
||||
|
||||
ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg);
|
||||
|
||||
|
|
|
|||
|
|
@ -680,8 +680,8 @@ TEST(CApiTest, access_tensor_data_elements) {
|
|||
Ort::Value tensor = Ort::Value::CreateTensor<float>(info, values.data(), values.size(), shape.data(), shape.size());
|
||||
|
||||
float expected_value = 0;
|
||||
for (size_t row = 0; row < (size_t)shape[0]; row++) {
|
||||
for (size_t col = 0; col < (size_t)shape[1]; col++) {
|
||||
for (int64_t row = 0; row < shape[0]; row++) {
|
||||
for (int64_t col = 0; col < shape[1]; col++) {
|
||||
ASSERT_EQ(expected_value++, tensor.At<float>({row, col}));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue