onnxruntime/objectivec/ort_value.mm

247 lines
8.6 KiB
Text
Raw Normal View History

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import "ort_value_internal.h"
#include <optional>
#import "cxx_api.h"
#import "error_utils.h"
#import "ort_enums_internal.h"
NS_ASSUME_NONNULL_BEGIN
namespace {
// Converts a C++ API tensor type-and-shape descriptor into the public
// ORTTensorTypeAndShapeInfo object.
//
// CXXAPITensorTypeAndShapeInfo: the C++ API descriptor whose element type and shape are read.
// Returns a newly allocated ORTTensorTypeAndShapeInfo with `elementType` and `shape` set.
ORTTensorTypeAndShapeInfo* CXXAPIToPublicTensorTypeAndShapeInfo(
    const Ort::ConstTensorTypeAndShapeInfo& CXXAPITensorTypeAndShapeInfo) {
  auto* result = [[ORTTensorTypeAndShapeInfo alloc] init];
  const auto elementType = CXXAPITensorTypeAndShapeInfo.GetElementType();
  const std::vector<int64_t> shape = CXXAPITensorTypeAndShapeInfo.GetShape();

  result.elementType = CAPIToPublicTensorElementType(elementType);

  // Box each dimension as an NSNumber, preserving the order reported by GetShape().
  NSMutableArray<NSNumber*>* shapeArray = [[NSMutableArray alloc] initWithCapacity:shape.size()];
  for (const int64_t dim : shape) {
    [shapeArray addObject:@(dim)];
  }
  result.shape = shapeArray;

  return result;
}
// Builds the public ORTValueTypeInfo that corresponds to a C++ API Ort::TypeInfo.
//
// The `type` property is always set. `tensorTypeAndShapeInfo` is set only when the
// ONNX type is ONNX_TYPE_TENSOR; for every other type it is left untouched.
ORTValueTypeInfo* CXXAPIToPublicValueTypeInfo(
    const Ort::TypeInfo& CXXAPITypeInfo) {
  auto* publicTypeInfo = [[ORTValueTypeInfo alloc] init];

  const auto onnxType = CXXAPITypeInfo.GetONNXType();
  publicTypeInfo.type = CAPIToPublicValueType(onnxType);

  // Non-tensor values carry no tensor type/shape details.
  if (onnxType != ONNX_TYPE_TENSOR) {
    return publicTypeInfo;
  }

  const auto tensorInfo = CXXAPITypeInfo.GetTensorTypeAndShapeInfo();
  publicTypeInfo.tensorTypeAndShapeInfo = CXXAPIToPublicTensorTypeAndShapeInfo(tensorInfo);
  return publicTypeInfo;
}
// Computes out = a * b with overflow detection.
// Returns true iff the product fits in size_t. On overflow, `out` still receives
// whatever __builtin_mul_overflow stores (the wrapped product) and false is returned.
bool SafeMultiply(size_t a, size_t b, size_t& out) {
  const bool overflowed = __builtin_mul_overflow(a, b, &out);
  return !overflowed;
}
} // namespace
// Class extension holding ORTValue state that is private to this file.
@interface ORTValue ()
// Strong reference to any caller-provided buffer that backs the tensor, kept so the
// buffer stays alive for the lifetime of the ORTValue.
// Assigned in initWithCXXAPIOrtValue:externalTensorData:error:. It is nil when the
// value was created without an external buffer (e.g. the string tensor initializer
// passes nil).
@property(nonatomic, nullable) NSMutableData* externalTensorData;
@end
// Objective-C wrapper around a C++ API Ort::Value.
// Both ivars are std::optional and are populated by
// initWithCXXAPIOrtValue:externalTensorData:error:.
@implementation ORTValue {
// The wrapped C++ API value; move-assigned from the value passed to the initializer.
std::optional<Ort::Value> _value;
// Type information obtained from the wrapped value via GetTypeInfo() at initialization.
std::optional<Ort::TypeInfo> _typeInfo;
}
#pragma mark - Public
// Creates a tensor ORTValue over the bytes of `tensorData`.
//
// tensorData:  buffer handed to Ort::Value::CreateTensor (mutableBytes/length); it is also
//              passed on as externalTensorData so it is retained by the new ORTValue.
// elementType: public element type; ORTTensorElementDataTypeString is rejected here.
// shape:       tensor dimensions, each read via longLongValue.
// error:       optional out-parameter populated by the catch macro on failure.
// Returns the initialized instance, or nil on failure.
- (nullable instancetype)initWithTensorData:(NSMutableData*)tensorData
                                elementType:(ORTTensorElementDataType)elementType
                                      shape:(NSArray<NSNumber*>*)shape
                                      error:(NSError**)error {
  try {
    // String tensors cannot be described by a flat byte buffer; they have a dedicated initializer.
    if (elementType == ORTTensorElementDataTypeString) {
      ORT_CXX_API_THROW(
          "ORTTensorElementDataTypeString element type provided. "
          "Please call initWithTensorStringData:shape:error: instead to create an ORTValue with string data.",
          ORT_INVALID_ARGUMENT);
    }

    const auto cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    const auto onnxElementType = PublicToCAPITensorElementType(elementType);

    // Unbox the NSNumber dimensions into the int64_t array expected by the C++ API.
    std::vector<int64_t> dims;
    dims.reserve(shape.count);
    for (NSNumber* boxedDim in shape) {
      dims.push_back(boxedDim.longLongValue);
    }

    Ort::Value tensor = Ort::Value::CreateTensor(
        cpuMemoryInfo, tensorData.mutableBytes, tensorData.length,
        dims.data(), dims.size(), onnxElementType);

    return [self initWithCXXAPIOrtValue:std::move(tensor)
                     externalTensorData:tensorData
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Creates a string tensor ORTValue from an array of NSString elements.
//
// tensorStringData: the string elements, in the order they are written into the tensor.
// shape:            tensor dimensions; each must be non-negative and their product must
//                   equal tensorStringData.count.
// error:            optional out-parameter populated by the catch macro on failure.
// Returns the initialized instance, or nil on failure (overflowing/negative shape,
// element count mismatch, a string that has no UTF-8 representation, or an ORT error).
- (nullable instancetype)initWithTensorStringData:(NSArray<NSString*>*)tensorStringData
                                            shape:(NSArray<NSNumber*>*)shape
                                            error:(NSError**)error {
  try {
    Ort::AllocatorWithDefaultOptions allocator;

    // Convert the shape while accumulating the element count with overflow checking.
    size_t tensorSize = 1U;
    const auto shapeVector = [&tensorSize, shape]() {
      std::vector<int64_t> result{};
      result.reserve(shape.count);
      for (NSNumber* dim in shape) {
        const auto dimValue = dim.longLongValue;
        if (dimValue < 0 || !SafeMultiply(static_cast<size_t>(dimValue), tensorSize, tensorSize)) {
          ORT_CXX_API_THROW("Failed to compute the tensor size.", ORT_RUNTIME_EXCEPTION);
        }
        result.push_back(dimValue);
      }
      return result;
    }();

    if (tensorSize != [tensorStringData count]) {
      ORT_CXX_API_THROW(
          "Computed tensor size does not equal the length of the provided tensor string data.",
          ORT_INVALID_ARGUMENT);
    }

    Ort::Value ortValue = Ort::Value::CreateTensor(
        allocator, shapeVector.data(), shapeVector.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);

    size_t index = 0;
    for (NSString* stringData in tensorStringData) {
      // -UTF8String can return NULL when the string cannot be represented as UTF-8
      // (e.g. it contains an unpaired surrogate). Report that as an error instead of
      // passing a null pointer to the C++ API.
      // NOTE: the element is passed as a C string, so it ends at the first NUL character.
      const char* utf8String = [stringData UTF8String];
      if (utf8String == nullptr) {
        ORT_CXX_API_THROW("Failed to convert a tensor string element to UTF-8.", ORT_INVALID_ARGUMENT);
      }
      ortValue.FillStringTensorElement(utf8String, index++);
    }

    return [self initWithCXXAPIOrtValue:std::move(ortValue)
                     externalTensorData:nil
                                  error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the public type information for this value, converted from the stored
// C++ API type info. On failure the catch macro handles `error` and the nil return.
- (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
  try {
    const Ort::TypeInfo& storedTypeInfo = *_typeInfo;
    return CXXAPIToPublicValueTypeInfo(storedTypeInfo);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the public tensor element type and shape for this value.
// Fails with "ORTValue is not a tensor." when the stored type info yields no
// tensor type-and-shape info.
- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
  try {
    const auto tensorInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (tensorInfo) {
      return CXXAPIToPublicTensorTypeAndShapeInfo(tensorInfo);
    }
    ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns an NSMutableData view over this tensor's raw element bytes.
//
// The data object is created with dataWithBytesNoCopy:length:freeWhenDone:NO, so it
// neither copies nor frees the tensor buffer.
// NOTE(review): presumably callers must keep this ORTValue alive while using the
// returned data - confirm against the public header documentation.
//
// Fails when the value is not a tensor, when it holds string data, when the byte
// length computation overflows, or when the ORT API reports an error.
- (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
  try {
    const auto tensorInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }

    const auto onnxElementType = tensorInfo.GetElementType();
    if (onnxElementType == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
      ORT_CXX_API_THROW(
          "This ORTValue holds string data. Please call tensorStringDataWithError: "
          "instead to retrieve the string data from this ORTValue.",
          ORT_RUNTIME_EXCEPTION);
    }

    // Byte length = element count * bytes per element, guarded against overflow.
    const size_t elementCount = tensorInfo.GetElementCount();
    const size_t bytesPerElement = SizeOfCAPITensorElementType(onnxElementType);
    size_t byteLength;
    if (!SafeMultiply(elementCount, bytesPerElement, byteLength)) {
      ORT_CXX_API_THROW("failed to compute tensor data length", ORT_RUNTIME_EXCEPTION);
    }

    void* tensorBytes;
    Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(*_value, &tensorBytes));

    return [NSMutableData dataWithBytesNoCopy:tensorBytes
                                       length:byteLength
                                 freeWhenDone:NO];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the elements of a string tensor as an array of NSString objects.
//
// The string content is fetched into one contiguous buffer together with the start
// offset of every element; each element is then decoded as UTF-8.
//
// Fails when the value is not a tensor, when the ORT API reports an error, or when
// an element's bytes are not valid UTF-8.
- (nullable NSArray<NSString*>*)tensorStringDataWithError:(NSError**)error {
  try {
    const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
    if (!tensorTypeAndShapeInfo) {
      ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
    }

    const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
    const size_t tensorStringDataLength = _value->GetStringTensorDataLength();
    std::vector<char> tensorStringData(tensorStringDataLength, '\0');
    std::vector<size_t> offsets(elementCount);
    _value->GetStringTensorContent(tensorStringData.data(), tensorStringDataLength,
                                   offsets.data(), offsets.size());

    NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:elementCount];
    for (size_t idx = 0; idx < elementCount; ++idx) {
      // An element spans from its own offset to the next element's offset; the last
      // element runs to the end of the buffer.
      const size_t strLength = (idx == elementCount - 1) ? tensorStringDataLength - offsets[idx]
                                                         : offsets[idx + 1] - offsets[idx];
      NSString* _Nullable element = [[NSString alloc] initWithBytes:tensorStringData.data() + offsets[idx]
                                                             length:strLength
                                                           encoding:NSUTF8StringEncoding];
      // initWithBytes:length:encoding: returns nil when the bytes are not valid UTF-8.
      // Adding nil to an NSMutableArray raises an Objective-C exception, so surface
      // the problem as a regular error instead.
      if (element == nil) {
        ORT_CXX_API_THROW("Failed to decode a tensor string element as UTF-8.", ORT_RUNTIME_EXCEPTION);
      }
      [result addObject:element];
    }
    return result;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
#pragma mark - Internal
// Internal initializer that takes ownership of an existing C++ API Ort::Value.
//
// existingCXXAPIOrtValue: the value to wrap; it is moved into this instance.
// externalTensorData:     optional buffer backing the tensor, retained by this instance.
// error:                  optional out-parameter populated by the catch macro on failure.
// Returns the initialized instance, or nil on failure.
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
                             externalTensorData:(nullable NSMutableData*)externalTensorData
                                          error:(NSError**)error {
  self = [super init];
  if (self == nil) {
    return nil;
  }

  try {
    // Read the type info first, while the incoming value has not been moved from.
    _typeInfo.emplace(existingCXXAPIOrtValue.GetTypeInfo());
    _externalTensorData = externalTensorData;

    // transfer C++ Ort::Value ownership to this instance
    _value.emplace(std::move(existingCXXAPIOrtValue));

    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns a reference to the wrapped C++ API value.
// The optional is dereferenced without checking that it holds a value.
- (Ort::Value&)CXXAPIOrtValue {
  Ort::Value& wrappedValue = *_value;
  return wrappedValue;
}
@end
// ORTValueTypeInfo has no method implementations in this file.
// Its `type` and `tensorTypeAndShapeInfo` properties are assigned by
// CXXAPIToPublicValueTypeInfo; presumably they are declared in the public header and
// auto-synthesized - confirm.
@implementation ORTValueTypeInfo
@end
// ORTTensorTypeAndShapeInfo has no method implementations in this file.
// Its `elementType` and `shape` properties are assigned by
// CXXAPIToPublicTensorTypeAndShapeInfo; presumably they are declared in the public
// header and auto-synthesized - confirm.
@implementation ORTTensorTypeAndShapeInfo
@end
NS_ASSUME_NONNULL_END