mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
[objc] Fix possible leak of OrtValue in initializer. (#16487)
Fix possible leak of OrtValue in initializer. There was a possible early return before ownership was transferred to the internal C++ Ort::Value.
This commit is contained in:
parent
8fc3037ff4
commit
05c4566fe9
5 changed files with 71 additions and 46 deletions
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
#import "cxx_utils.h"
|
||||
|
||||
#import <vector>
|
||||
#import <optional>
|
||||
#import <string>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#import "error_utils.h"
|
||||
|
||||
|
|
@ -54,24 +54,24 @@ std::vector<std::string> toStdStringVector(NSArray<NSString*>* strs) {
|
|||
NSArray<NSString*>* toNSStringNSArray(const std::vector<std::string>& strs) {
|
||||
NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:strs.size()];
|
||||
for (const std::string& str : strs) {
|
||||
NSString* nsStr = [NSString stringWithUTF8String:str.c_str()];
|
||||
if (nsStr) {
|
||||
[result addObject:nsStr];
|
||||
} else {
|
||||
ORT_CXX_API_THROW("Failed to convert std::string to NSString", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
[result addObject:toNSString(str)];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& values, NSError** error) {
|
||||
NSMutableArray<ORTValue*>* result = [NSMutableArray arrayWithCapacity:values.size()];
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
ORTValue* val = [[ORTValue alloc] initWithCAPIOrtValue:values[i] externalTensorData:nil error:error];
|
||||
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& CAPIValues, NSError** error) {
|
||||
NSMutableArray<ORTValue*>* result = [NSMutableArray arrayWithCapacity:CAPIValues.size()];
|
||||
for (size_t i = 0; i < CAPIValues.size(); ++i) {
|
||||
// Wrap the C OrtValue in a C++ Ort::Value to automatically handle its release.
|
||||
// Then, transfer that C++ Ort::Value to a new ORTValue.
|
||||
Ort::Value CXXAPIValue{CAPIValues[i]};
|
||||
ORTValue* val = [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(CXXAPIValue)
|
||||
externalTensorData:nil
|
||||
error:error];
|
||||
if (!val) {
|
||||
// clean up all the C API Ortvalues which haven't been wrapped by ORTValue
|
||||
for (size_t j = i; j < values.size(); ++j) {
|
||||
Ort::GetApi().ReleaseValue(values[j]);
|
||||
// clean up remaining C OrtValues which haven't been wrapped by a C++ Ort::Value yet
|
||||
for (size_t j = i + 1; j < CAPIValues.size(); ++j) {
|
||||
Ort::GetApi().ReleaseValue(CAPIValues[j]);
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
|
@ -82,6 +82,7 @@ NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValu
|
|||
|
||||
std::vector<const OrtValue*> getWrappedCAPIOrtValues(NSArray<ORTValue*>* values) {
|
||||
std::vector<const OrtValue*> result;
|
||||
result.reserve(values.count);
|
||||
for (ORTValue* val in values) {
|
||||
result.push_back(static_cast<const OrtValue*>([val CXXAPIOrtValue]));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,22 +66,26 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
}
|
||||
|
||||
std::vector<const char*> inputNames, outputNames;
|
||||
std::vector<const OrtValue*> inputValues;
|
||||
std::vector<OrtValue*> outputValues;
|
||||
std::vector<const OrtValue*> inputCAPIValues;
|
||||
std::vector<OrtValue*> outputCAPIValues;
|
||||
|
||||
inputNames.reserve(inputs.count);
|
||||
inputCAPIValues.reserve(inputs.count);
|
||||
for (NSString* inputName in inputs) {
|
||||
inputNames.push_back(inputName.UTF8String);
|
||||
inputValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
|
||||
inputCAPIValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
|
||||
}
|
||||
|
||||
outputNames.reserve(outputs.count);
|
||||
outputCAPIValues.reserve(outputs.count);
|
||||
for (NSString* outputName in outputs) {
|
||||
outputNames.push_back(outputName.UTF8String);
|
||||
outputValues.push_back(static_cast<OrtValue*>([outputs[outputName] CXXAPIOrtValue]));
|
||||
outputCAPIValues.push_back(static_cast<OrtValue*>([outputs[outputName] CXXAPIOrtValue]));
|
||||
}
|
||||
|
||||
Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions],
|
||||
inputNames.data(), inputValues.data(), inputNames.size(),
|
||||
outputNames.data(), outputNames.size(), outputValues.data()));
|
||||
inputNames.data(), inputCAPIValues.data(), inputNames.size(),
|
||||
outputNames.data(), outputNames.size(), outputCAPIValues.data()));
|
||||
|
||||
return YES;
|
||||
}
|
||||
|
|
@ -103,30 +107,39 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
NSArray<NSString*>* outputNameArray = outputNameSet.allObjects;
|
||||
|
||||
std::vector<const char*> inputNames, outputNames;
|
||||
std::vector<const OrtValue*> inputValues;
|
||||
std::vector<OrtValue*> outputValues;
|
||||
std::vector<const OrtValue*> inputCAPIValues;
|
||||
std::vector<OrtValue*> outputCAPIValues;
|
||||
|
||||
inputNames.reserve(inputs.count);
|
||||
inputCAPIValues.reserve(inputs.count);
|
||||
for (NSString* inputName in inputs) {
|
||||
inputNames.push_back(inputName.UTF8String);
|
||||
inputValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
|
||||
inputCAPIValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
|
||||
}
|
||||
|
||||
outputNames.reserve(outputNameArray.count);
|
||||
outputCAPIValues.reserve(outputNameArray.count);
|
||||
for (NSString* outputName in outputNameArray) {
|
||||
outputNames.push_back(outputName.UTF8String);
|
||||
outputValues.push_back(nullptr);
|
||||
outputCAPIValues.push_back(nullptr);
|
||||
}
|
||||
|
||||
Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions],
|
||||
inputNames.data(), inputValues.data(), inputNames.size(),
|
||||
outputNames.data(), outputNames.size(), outputValues.data()));
|
||||
inputNames.data(), inputCAPIValues.data(), inputNames.size(),
|
||||
outputNames.data(), outputNames.size(), outputCAPIValues.data()));
|
||||
|
||||
NSMutableDictionary<NSString*, ORTValue*>* outputs = [[NSMutableDictionary alloc] init];
|
||||
for (NSUInteger i = 0; i < outputNameArray.count; ++i) {
|
||||
ORTValue* outputValue = [[ORTValue alloc] initWithCAPIOrtValue:outputValues[i] externalTensorData:nil error:error];
|
||||
// Wrap the C OrtValue in a C++ Ort::Value to automatically handle its release.
|
||||
// Then, transfer that C++ Ort::Value to a new ORTValue.
|
||||
Ort::Value outputCXXAPIValue{outputCAPIValues[i]};
|
||||
ORTValue* outputValue = [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(outputCXXAPIValue)
|
||||
externalTensorData:nil
|
||||
error:error];
|
||||
if (!outputValue) {
|
||||
// clean up remaining C API OrtValues which haven't been wrapped by an ORTValue yet
|
||||
for (NSUInteger j = i; j < outputNameArray.count; ++j) {
|
||||
Ort::GetApi().ReleaseValue(outputValues[j]);
|
||||
// clean up remaining C OrtValues which haven't been wrapped by a C++ Ort::Value yet
|
||||
for (NSUInteger j = i + 1; j < outputNameArray.count; ++j) {
|
||||
Ort::GetApi().ReleaseValue(outputCAPIValues[j]);
|
||||
}
|
||||
return nil;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,9 +198,9 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
error:(NSError**)error {
|
||||
try {
|
||||
Ort::Value val = [self CXXAPIOrtTrainingSession].ToBuffer(onlyTrainable);
|
||||
return [[ORTValue alloc] initWithCAPIOrtValue:val.release()
|
||||
externalTensorData:nil
|
||||
error:error];
|
||||
return [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(val)
|
||||
externalTensorData:nil
|
||||
error:error];
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -85,9 +85,9 @@ bool SafeMultiply(size_t a, size_t b, size_t& out) {
|
|||
memoryInfo, tensorData.mutableBytes, tensorData.length,
|
||||
shapeVector.data(), shapeVector.size(), ONNXElementType);
|
||||
|
||||
return [self initWithCAPIOrtValue:ortValue.release()
|
||||
externalTensorData:tensorData
|
||||
error:error];
|
||||
return [self initWithCXXAPIOrtValue:std::move(ortValue)
|
||||
externalTensorData:tensorData
|
||||
error:error];
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
|
||||
}
|
||||
|
|
@ -129,17 +129,19 @@ bool SafeMultiply(size_t a, size_t b, size_t& out) {
|
|||
|
||||
#pragma mark - Internal
|
||||
|
||||
- (nullable instancetype)initWithCAPIOrtValue:(OrtValue*)CAPIOrtValue
|
||||
externalTensorData:(nullable NSMutableData*)externalTensorData
|
||||
error:(NSError**)error {
|
||||
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
|
||||
externalTensorData:(nullable NSMutableData*)externalTensorData
|
||||
error:(NSError**)error {
|
||||
if ((self = [super init]) == nil) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
try {
|
||||
_value = Ort::Value{CAPIOrtValue};
|
||||
_typeInfo = _value->GetTypeInfo();
|
||||
_typeInfo = existingCXXAPIOrtValue.GetTypeInfo();
|
||||
_externalTensorData = externalTensorData;
|
||||
|
||||
// transfer C++ Ort::Value ownership to this instance
|
||||
_value = std::move(existingCXXAPIOrtValue);
|
||||
return self;
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error);
|
||||
|
|
|
|||
|
|
@ -9,9 +9,18 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
|
||||
@interface ORTValue ()
|
||||
|
||||
- (nullable instancetype)initWithCAPIOrtValue:(OrtValue*)CAPIOrtValue
|
||||
externalTensorData:(nullable NSMutableData*)externalTensorData
|
||||
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
|
||||
/**
|
||||
* Creates a value from an existing C++ API Ort::Value and takes ownership from it.
|
||||
* Note: Ownership is guaranteed to be transferred on success but not otherwise.
|
||||
*
|
||||
* @param existingCXXAPIOrtValue The existing C++ API Ort::Value.
|
||||
* @param externalTensorData Any external tensor data referenced by `existingCXXAPIOrtValue`.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The instance, or nil if an error occurs.
|
||||
*/
|
||||
- (nullable instancetype)initWithCXXAPIOrtValue:(Ort::Value&&)existingCXXAPIOrtValue
|
||||
externalTensorData:(nullable NSMutableData*)externalTensorData
|
||||
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
- (Ort::Value&)CXXAPIOrtValue;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue