[objc] Fix possible leak of OrtValue in initializer. (#16487)

Fix possible leak of OrtValue in initializer. There was a possible early return before ownership was transferred to the internal C++ Ort::Value.
This commit is contained in:
Edward Chen 2023-06-29 17:37:16 -07:00 committed by GitHub
parent 8fc3037ff4
commit 05c4566fe9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 71 additions and 46 deletions

View file

@ -3,9 +3,9 @@
#import "cxx_utils.h"
#import <vector>
#import <optional>
#import <string>
#include <vector>
#include <optional>
#include <string>
#import "error_utils.h"
@ -54,24 +54,24 @@ std::vector<std::string> toStdStringVector(NSArray<NSString*>* strs) {
NSArray<NSString*>* toNSStringNSArray(const std::vector<std::string>& strs) {
NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:strs.size()];
for (const std::string& str : strs) {
NSString* nsStr = [NSString stringWithUTF8String:str.c_str()];
if (nsStr) {
[result addObject:nsStr];
} else {
ORT_CXX_API_THROW("Failed to convert std::string to NSString", ORT_INVALID_ARGUMENT);
}
[result addObject:toNSString(str)];
}
return result;
}
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& values, NSError** error) {
NSMutableArray<ORTValue*>* result = [NSMutableArray arrayWithCapacity:values.size()];
for (size_t i = 0; i < values.size(); ++i) {
ORTValue* val = [[ORTValue alloc] initWithCAPIOrtValue:values[i] externalTensorData:nil error:error];
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& CAPIValues, NSError** error) {
NSMutableArray<ORTValue*>* result = [NSMutableArray arrayWithCapacity:CAPIValues.size()];
for (size_t i = 0; i < CAPIValues.size(); ++i) {
// Wrap the C OrtValue in a C++ Ort::Value to automatically handle its release.
// Then, transfer that C++ Ort::Value to a new ORTValue.
Ort::Value CXXAPIValue{CAPIValues[i]};
ORTValue* val = [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(CXXAPIValue)
externalTensorData:nil
error:error];
if (!val) {
// clean up all the C API Ortvalues which haven't been wrapped by ORTValue
for (size_t j = i; j < values.size(); ++j) {
Ort::GetApi().ReleaseValue(values[j]);
// clean up remaining C OrtValues which haven't been wrapped by a C++ Ort::Value yet
for (size_t j = i + 1; j < CAPIValues.size(); ++j) {
Ort::GetApi().ReleaseValue(CAPIValues[j]);
}
return nil;
}
@ -82,6 +82,7 @@ NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValu
std::vector<const OrtValue*> getWrappedCAPIOrtValues(NSArray<ORTValue*>* values) {
std::vector<const OrtValue*> result;
result.reserve(values.count);
for (ORTValue* val in values) {
result.push_back(static_cast<const OrtValue*>([val CXXAPIOrtValue]));
}

View file

@ -66,22 +66,26 @@ NS_ASSUME_NONNULL_BEGIN
}
std::vector<const char*> inputNames, outputNames;
std::vector<const OrtValue*> inputValues;
std::vector<OrtValue*> outputValues;
std::vector<const OrtValue*> inputCAPIValues;
std::vector<OrtValue*> outputCAPIValues;
inputNames.reserve(inputs.count);
inputCAPIValues.reserve(inputs.count);
for (NSString* inputName in inputs) {
inputNames.push_back(inputName.UTF8String);
inputValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
inputCAPIValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
}
outputNames.reserve(outputs.count);
outputCAPIValues.reserve(outputs.count);
for (NSString* outputName in outputs) {
outputNames.push_back(outputName.UTF8String);
outputValues.push_back(static_cast<OrtValue*>([outputs[outputName] CXXAPIOrtValue]));
outputCAPIValues.push_back(static_cast<OrtValue*>([outputs[outputName] CXXAPIOrtValue]));
}
Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions],
inputNames.data(), inputValues.data(), inputNames.size(),
outputNames.data(), outputNames.size(), outputValues.data()));
inputNames.data(), inputCAPIValues.data(), inputNames.size(),
outputNames.data(), outputNames.size(), outputCAPIValues.data()));
return YES;
}
@ -103,30 +107,39 @@ NS_ASSUME_NONNULL_BEGIN
NSArray<NSString*>* outputNameArray = outputNameSet.allObjects;
std::vector<const char*> inputNames, outputNames;
std::vector<const OrtValue*> inputValues;
std::vector<OrtValue*> outputValues;
std::vector<const OrtValue*> inputCAPIValues;
std::vector<OrtValue*> outputCAPIValues;
inputNames.reserve(inputs.count);
inputCAPIValues.reserve(inputs.count);
for (NSString* inputName in inputs) {
inputNames.push_back(inputName.UTF8String);
inputValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
inputCAPIValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
}
outputNames.reserve(outputNameArray.count);
outputCAPIValues.reserve(outputNameArray.count);
for (NSString* outputName in outputNameArray) {
outputNames.push_back(outputName.UTF8String);
outputValues.push_back(nullptr);
outputCAPIValues.push_back(nullptr);
}
Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions],
inputNames.data(), inputValues.data(), inputNames.size(),
outputNames.data(), outputNames.size(), outputValues.data()));
inputNames.data(), inputCAPIValues.data(), inputNames.size(),
outputNames.data(), outputNames.size(), outputCAPIValues.data()));
NSMutableDictionary<NSString*, ORTValue*>* outputs = [[NSMutableDictionary alloc] init];
for (NSUInteger i = 0; i < outputNameArray.count; ++i) {
ORTValue* outputValue = [[ORTValue alloc] initWithCAPIOrtValue:outputValues[i] externalTensorData:nil error:error];
// Wrap the C OrtValue in a C++ Ort::Value to automatically handle its release.
// Then, transfer that C++ Ort::Value to a new ORTValue.
Ort::Value outputCXXAPIValue{outputCAPIValues[i]};
ORTValue* outputValue = [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(outputCXXAPIValue)
externalTensorData:nil
error:error];
if (!outputValue) {
// clean up remaining C API OrtValues which haven't been wrapped by an ORTValue yet
for (NSUInteger j = i; j < outputNameArray.count; ++j) {
Ort::GetApi().ReleaseValue(outputValues[j]);
// clean up remaining C OrtValues which haven't been wrapped by a C++ Ort::Value yet
for (NSUInteger j = i + 1; j < outputNameArray.count; ++j) {
Ort::GetApi().ReleaseValue(outputCAPIValues[j]);
}
return nil;
}

View file

@ -198,9 +198,9 @@ NS_ASSUME_NONNULL_BEGIN
error:(NSError**)error {
try {
Ort::Value val = [self CXXAPIOrtTrainingSession].ToBuffer(onlyTrainable);
return [[ORTValue alloc] initWithCAPIOrtValue:val.release()
externalTensorData:nil
error:error];
return [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(val)
externalTensorData:nil
error:error];
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

View file

@ -85,9 +85,9 @@ bool SafeMultiply(size_t a, size_t b, size_t& out) {
memoryInfo, tensorData.mutableBytes, tensorData.length,
shapeVector.data(), shapeVector.size(), ONNXElementType);
return [self initWithCAPIOrtValue:ortValue.release()
externalTensorData:tensorData
error:error];
return [self initWithCXXAPIOrtValue:std::move(ortValue)
externalTensorData:tensorData
error:error];
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
@ -129,17 +129,19 @@ bool SafeMultiply(size_t a, size_t b, size_t& out) {
#pragma mark - Internal
- (nullable instancetype)initWithCAPIOrtValue:(OrtValue*)CAPIOrtValue
externalTensorData:(nullable NSMutableData*)externalTensorData
error:(NSError**)error {
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
externalTensorData:(nullable NSMutableData*)externalTensorData
error:(NSError**)error {
if ((self = [super init]) == nil) {
return nil;
}
try {
_value = Ort::Value{CAPIOrtValue};
_typeInfo = _value->GetTypeInfo();
_typeInfo = existingCXXAPIOrtValue.GetTypeInfo();
_externalTensorData = externalTensorData;
// transfer C++ Ort::Value ownership to this instance
_value = std::move(existingCXXAPIOrtValue);
return self;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error);

View file

@ -9,9 +9,18 @@ NS_ASSUME_NONNULL_BEGIN
@interface ORTValue ()
- (nullable instancetype)initWithCAPIOrtValue:(OrtValue*)CAPIOrtValue
externalTensorData:(nullable NSMutableData*)externalTensorData
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
/**
* Creates a value from an existing C++ API Ort::Value and takes ownership from it.
* Note: Ownership is guaranteed to be transferred on success but not otherwise.
*
* @param existingCXXAPIOrtValue The existing C++ API Ort::Value.
* @param externalTensorData Any external tensor data referenced by `existingCXXAPIOrtValue`.
* @param error Optional error information set if an error occurs.
* @return The instance, or nil if an error occurs.
*/
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
externalTensorData:(nullable NSMutableData*)externalTensorData
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
- (Ort::Value&)CXXAPIOrtValue;