// Source: onnxruntime/objectivec/ort_value.mm (161 lines, 4.9 KiB)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import "ort_value_internal.h"
#include <optional>
#import "cxx_api.h"
#import "error_utils.h"
#import "ort_enums_internal.h"
NS_ASSUME_NONNULL_BEGIN
namespace {
// Converts a C++ API tensor type-and-shape description into a new public
// ORTTensorTypeAndShapeInfo object.
//
// The element type is mapped with CAPIToPublicTensorElementType and every
// dimension returned by GetShape() is boxed as an NSNumber, in order.
ORTTensorTypeAndShapeInfo* CXXAPIToPublicTensorTypeAndShapeInfo(
    const Ort::ConstTensorTypeAndShapeInfo& CXXAPITensorTypeAndShapeInfo) {
  auto* result = [[ORTTensorTypeAndShapeInfo alloc] init];

  const auto elementType = CXXAPITensorTypeAndShapeInfo.GetElementType();
  const std::vector<int64_t> shape = CXXAPITensorTypeAndShapeInfo.GetShape();

  result.elementType = CAPIToPublicTensorElementType(elementType);

  auto* shapeArray = [[NSMutableArray<NSNumber*> alloc] initWithCapacity:shape.size()];
  for (const int64_t dim : shape) {
    [shapeArray addObject:@(dim)];
  }
  result.shape = shapeArray;

  return result;
}
// Builds a new public ORTValueTypeInfo from a C++ API Ort::TypeInfo.
//
// The ONNX value type is always translated with CAPIToPublicValueType. The
// tensor type-and-shape details are filled in only when the value is a tensor
// (ONNX_TYPE_TENSOR); for every other type that property is left unset.
ORTValueTypeInfo* CXXAPIToPublicValueTypeInfo(
    const Ort::TypeInfo& CXXAPITypeInfo) {
  ORTValueTypeInfo* publicTypeInfo = [[ORTValueTypeInfo alloc] init];

  const auto onnxType = CXXAPITypeInfo.GetONNXType();
  publicTypeInfo.type = CAPIToPublicValueType(onnxType);

  if (onnxType != ONNX_TYPE_TENSOR) {
    return publicTypeInfo;
  }

  const auto tensorInfo = CXXAPITypeInfo.GetTensorTypeAndShapeInfo();
  publicTypeInfo.tensorTypeAndShapeInfo = CXXAPIToPublicTensorTypeAndShapeInfo(tensorInfo);
  return publicTypeInfo;
}
// Multiplies a by b and stores the product in out.
// Returns true when the multiplication did not overflow size_t, false otherwise.
bool SafeMultiply(size_t a, size_t b, size_t& out) {
  const bool overflowed = __builtin_mul_overflow(a, b, &out);
  return !overflowed;
}
} // namespace
@interface ORTValue ()

// Externally supplied tensor storage, if any. Holding it here keeps the buffer
// alive for as long as this ORTValue exists.
@property(nonatomic, strong, nullable) NSMutableData* externalTensorData;

@end
@implementation ORTValue {
std::optional<Ort::Value> _value;
std::optional<Ort::TypeInfo> _typeInfo;
}
#pragma mark - Public
// Creates a tensor value that uses the bytes of `tensorData` directly as its storage.
//
// `shape` is converted to a vector of int64_t dimensions and `elementType` to
// its C API counterpart before the tensor is created over CPU memory. The new
// OrtValue is handed to -initWithCAPIOrtValue:externalTensorData:error:
// together with `tensorData`, which that initializer stores on the instance.
//
// NOTE(review): C++ exceptions are handled by
// ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE, which presumably populates
// `error` and returns nil — confirm against error_utils.h.
- (nullable instancetype)initWithTensorData:(NSMutableData*)tensorData
                                elementType:(ORTTensorElementDataType)elementType
                                      shape:(NSArray<NSNumber*>*)shape
                                      error:(NSError**)error {
  try {
    const auto cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    const auto capiElementType = PublicToCAPITensorElementType(elementType);

    std::vector<int64_t> dimensions;
    dimensions.reserve(shape.count);
    for (NSNumber* dimension in shape) {
      dimensions.push_back(dimension.longLongValue);
    }

    Ort::Value tensor = Ort::Value::CreateTensor(
        cpuMemoryInfo, tensorData.mutableBytes, tensorData.length,
        dimensions.data(), dimensions.size(), capiElementType);

    // Ownership of the underlying OrtValue moves to the internal initializer.
    return [self initWithCAPIOrtValue:tensor.release()
                   externalTensorData:tensorData
                                error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the public type information for this value, converted from the
// Ort::TypeInfo stored at initialization.
- (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
  try {
    const Ort::TypeInfo& storedTypeInfo = *_typeInfo;
    ORTValueTypeInfo* publicTypeInfo = CXXAPIToPublicValueTypeInfo(storedTypeInfo);
    return publicTypeInfo;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the public tensor element type and shape for this value, converted
// from the tensor info of the Ort::TypeInfo stored at initialization.
- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
  try {
    return CXXAPIToPublicTensorTypeAndShapeInfo(_typeInfo->GetTensorTypeAndShapeInfo());
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns an NSMutableData wrapping this tensor's data buffer.
//
// The length is element count times element size, computed with an overflow
// check. The bytes are not copied and the returned object does not free them
// (freeWhenDone:NO).
//
// NOTE(review): the returned data does not retain this ORTValue, so the caller
// presumably has to keep the ORTValue alive while using the bytes — confirm.
- (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
  try {
    const auto typeAndShape = _typeInfo->GetTensorTypeAndShapeInfo();
    const size_t numElements = typeAndShape.GetElementCount();
    const size_t bytesPerElement = SizeOfCAPITensorElementType(typeAndShape.GetElementType());

    size_t byteLength = 0;
    const bool lengthFits = SafeMultiply(numElements, bytesPerElement, byteLength);
    if (!lengthFits) {
      ORT_CXX_API_THROW("failed to compute tensor data length", ORT_RUNTIME_EXCEPTION);
    }

    void* tensorBytes = nullptr;
    Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(*_value, &tensorBytes));

    return [NSMutableData dataWithBytesNoCopy:tensorBytes
                                       length:byteLength
                                 freeWhenDone:NO];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
#pragma mark - Internal
// Internal initializer: wraps a raw C API OrtValue pointer in an Ort::Value
// held by this instance, caches its type info, and stores the optional
// external tensor data alongside it.
- (nullable instancetype)initWithCAPIOrtValue:(OrtValue*)CAPIOrtValue
                           externalTensorData:(nullable NSMutableData*)externalTensorData
                                        error:(NSError**)error {
  self = [super init];
  if (self == nil) {
    return nil;
  }

  try {
    _value.emplace(CAPIOrtValue);
    // The type info is read from the wrapped value, so it must be set after _value.
    _typeInfo = _value->GetTypeInfo();
    _externalTensorData = externalTensorData;
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Internal accessor: returns a reference to the wrapped C++ API value.
- (Ort::Value&)CXXAPIOrtValue {
  Ort::Value& wrappedValue = *_value;
  return wrappedValue;
}
@end
// No methods are implemented here.
// NOTE(review): the `type` and `tensorTypeAndShapeInfo` properties assigned in
// this file are presumably declared in the public header and auto-synthesized — confirm.
@implementation ORTValueTypeInfo
@end
// No methods are implemented here.
// NOTE(review): the `elementType` and `shape` properties assigned in this file
// are presumably declared in the public header and auto-synthesized — confirm.
@implementation ORTTensorTypeAndShapeInfo
@end
NS_ASSUME_NONNULL_END