Refactor TensorAt, prepare for release (#5180)

* Refactor TensorAt
  locations* must be const and int64_t since our dims are int64_t
  Remove unnecessary copy of locations.
  Remove unnecesary casting and C-casting. Simplify implementation.
  Add a check for string type.
  Make CXX api return T& to fully expose C API in C++, const std::vector& by value as it
  covers more ground and eliminate redundant copy.
  Eliminate inner loop, compute strides first.
This commit is contained in:
Dmitri Smirnov 2020-09-16 10:20:45 -07:00 committed by GitHub
parent a20f8037f6
commit e6f85f338e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 40 additions and 22 deletions

View file

@ -1009,7 +1009,7 @@ struct OrtApi {
* This function only works for numeric tensors.
* This is a no-copy method whose pointer is only valid until the backing OrtValue is free'd.
*/
ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out);
/**
* Creates an allocator instance and registers it with the env to enable

View file

@ -371,7 +371,7 @@ struct Value : Base<OrtValue> {
const T* GetTensorData() const;
template <typename T>
T At(const std::initializer_list<size_t>& location);
T& At(const std::vector<int64_t>& location);
TypeInfo GetTypeInfo() const;
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;

View file

@ -777,10 +777,10 @@ const T* Value::GetTensorData() const {
}
template <typename T>
inline T Value::At(const std::initializer_list<size_t>& location) {
inline T& Value::At(const std::vector<int64_t>& location) {
static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
T* out;
std::vector<size_t> location_ = location;
ThrowOnError(GetApi().TensorAt(p_, location_.data(), location_.size(), (void**)&out));
ThrowOnError(GetApi().TensorAt(p_, location.data(), location.size(), (void**)&out));
return *out;
}

View file

@ -1727,27 +1727,45 @@ ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
return NULL;
}
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count,
ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count,
_Outptr_ void** out) {
TENSOR_READWRITE_API_BEGIN
//TODO: test if it's a string tensor
if (location_values_count != tensor->Shape().NumDimensions())
if(tensor->IsDataTypeString()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "this API does not support strings");
}
const auto& tensor_shape = tensor->Shape();
const auto num_dimensions = tensor_shape.NumDimensions();
if (location_values_count != num_dimensions) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "location dimensions do not match shape size");
std::vector<size_t> location(location_values_count);
}
for (size_t i = 0; i < location_values_count; i++) {
if (location_values[i] >= (size_t)tensor->Shape()[i])
if (location_values[i] >= tensor_shape[i] || location_values[i] < 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "invalid location range");
location[i] = location_values[i];
}
}
// data has row-major format
size_t offset = 0;
for (size_t i = 1; i <= tensor->Shape().NumDimensions(); i++) {
size_t sum = 1;
for (size_t j = i + 1; j <= tensor->Shape().NumDimensions(); j++) sum *= (size_t)tensor->Shape()[j - 1];
offset += location[i - 1] * sum;
// compute strides
// TensorPitches p;
std::vector<int64_t> strides(num_dimensions);
{
int64_t stride = 1;
for (size_t dim = num_dimensions; dim > 0; --dim) {
strides[dim - 1] = stride;
stride *= tensor_shape[dim - 1];
}
}
auto data = ((char*)tensor->MutableDataRaw()) + (tensor->DataType()->Size() * offset);
*out = (void*)data;
// For Scalers the offset would always be zero
int64_t offset = 0;
for (size_t i = 0; i < num_dimensions; i++) {
offset += location_values[i] * strides[i];
}
auto data = reinterpret_cast<char*>(tensor->MutableDataRaw()) + tensor->DataType()->Size() * offset;
*out = data;
return nullptr;
API_IMPL_END
}

View file

@ -228,7 +228,7 @@ ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char** ptr,
ORT_API_STATUS_IMPL(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options,
_In_z_ const char* config_key, _In_z_ const char* config_value);
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out);
ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg);

View file

@ -680,8 +680,8 @@ TEST(CApiTest, access_tensor_data_elements) {
Ort::Value tensor = Ort::Value::CreateTensor<float>(info, values.data(), values.size(), shape.data(), shape.size());
float expected_value = 0;
for (size_t row = 0; row < (size_t)shape[0]; row++) {
for (size_t col = 0; col < (size_t)shape[1]; col++) {
for (int64_t row = 0; row < shape[0]; row++) {
for (int64_t col = 0; col < shape[1]; col++) {
ASSERT_EQ(expected_value++, tensor.At<float>({row, col}));
}
}