onnxruntime/objectivec/ort_session.mm
Scott McKay 446c478fbd
Add iOS Swift Package Manager support (#15297)
### Description
<!-- Describe your changes. -->
Add Swift Package Manager (SPM) support for ORT, based on #14621
- uses the existing objective-c bindings
- some re-organization of the directory structure was required but the
contents of the files are unchanged, apart from adjustments due to file
movements

Add tool for updating ORT native pod used in the SPM package
Update CIs to use ORT native pod from build, and build/test using SPM



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
iOS developers are using SPM as much as cocoapods, so adding SPM means
both are catered for.
2023-04-20 16:18:35 +10:00

361 lines
10 KiB
Text

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import "ort_session_internal.h"
#include <optional>
#include <vector>
#import "cxx_api.h"
#import "error_utils.h"
#import "ort_enums_internal.h"
#import "ort_env_internal.h"
#import "ort_value_internal.h"
namespace {

// Selects which group of session names -[ORTSession namesWithType:error:] enumerates:
// the model inputs, the overridable initializers, or the model outputs.
enum class NamedValueType : int {
  Input,
  OverridableInitializer,
  Output
};

}  // namespace
NS_ASSUME_NONNULL_BEGIN
/// Objective-C wrapper around an Ort::Session.
///
/// Each method that can fail wraps its work in a `try` block. The block is followed by one of the
/// ORT_OBJC_API_IMPL_CATCH_* macros (error_utils.h), which receive the caller's `error` pointer.
/// Failure is signalled by returning NO / nil.
@implementation ORTSession {
  // The underlying C++ session. It is assigned by the designated initializer; if construction
  // fails, the initializer returns nil, so callers never observe an empty optional.
  std::optional<Ort::Session> _session;
}

#pragma mark - Public

/// Creates a session for the model at `path`.
///
/// @param env The environment whose Ort::Env the session is created in.
/// @param path Filesystem path to the model, passed to ORT as UTF-8.
/// @param sessionOptions Options for the session; default-constructed options are used when nil.
/// @param error Receives error information on failure.
/// @return The initialized session, or nil on failure.
- (nullable instancetype)initWithEnv:(ORTEnv*)env
                           modelPath:(NSString*)path
                      sessionOptions:(nullable ORTSessionOptions*)sessionOptions
                               error:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    if (!sessionOptions) {
      // No options supplied: fall back to freshly created defaults.
      sessionOptions = [[ORTSessionOptions alloc] initWithError:error];
      if (!sessionOptions) {
        return nil;
      }
    }

    _session = Ort::Session{[env CXXAPIOrtEnv],
                            path.UTF8String,
                            [sessionOptions CXXAPIOrtSessionOptions]};

    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Runs the model, writing results into caller-provided output values.
///
/// @param inputs Map from input name to input value.
/// @param outputs Map from output name to a pre-existing value that ORT writes into.
/// @param runOptions Per-run options; default-constructed options are used when nil.
/// @param error Receives error information on failure.
/// @return YES on success, NO on failure.
- (BOOL)runWithInputs:(NSDictionary<NSString*, ORTValue*>*)inputs
              outputs:(NSDictionary<NSString*, ORTValue*>*)outputs
           runOptions:(nullable ORTRunOptions*)runOptions
                error:(NSError**)error {
  try {
    if (!runOptions) {
      runOptions = [[ORTRunOptions alloc] initWithError:error];
      if (!runOptions) {
        return NO;
      }
    }

    // Parallel arrays of names and values, in the layout expected by OrtApi::Run.
    std::vector<const char*> inputNames, outputNames;
    std::vector<const OrtValue*> inputValues;
    std::vector<OrtValue*> outputValues;

    for (NSString* inputName in inputs) {
      inputNames.push_back(inputName.UTF8String);
      inputValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
    }

    for (NSString* outputName in outputs) {
      outputNames.push_back(outputName.UTF8String);
      outputValues.push_back(static_cast<OrtValue*>([outputs[outputName] CXXAPIOrtValue]));
    }

    Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions],
                                        inputNames.data(), inputValues.data(), inputNames.size(),
                                        outputNames.data(), outputNames.size(), outputValues.data()));

    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Runs the model and returns newly created values for the requested outputs.
///
/// @param inputs Map from input name to input value.
/// @param outputNameSet Names of the outputs to produce.
/// @param runOptions Per-run options; default-constructed options are used when nil.
/// @param error Receives error information on failure.
/// @return Map from output name to output value, or nil on failure.
- (nullable NSDictionary<NSString*, ORTValue*>*)runWithInputs:(NSDictionary<NSString*, ORTValue*>*)inputs
                                                  outputNames:(NSSet<NSString*>*)outputNameSet
                                                   runOptions:(nullable ORTRunOptions*)runOptions
                                                        error:(NSError**)error {
  try {
    if (!runOptions) {
      runOptions = [[ORTRunOptions alloc] initWithError:error];
      if (!runOptions) {
        return nil;
      }
    }

    // Fix an order for the set so that output names and values can be matched up by index.
    NSArray<NSString*>* outputNameArray = outputNameSet.allObjects;

    std::vector<const char*> inputNames, outputNames;
    std::vector<const OrtValue*> inputValues;
    std::vector<OrtValue*> outputValues;

    for (NSString* inputName in inputs) {
      inputNames.push_back(inputName.UTF8String);
      inputValues.push_back(static_cast<const OrtValue*>([inputs[inputName] CXXAPIOrtValue]));
    }

    for (NSString* outputName in outputNameArray) {
      outputNames.push_back(outputName.UTF8String);
      // nullptr asks Run to allocate the output value.
      outputValues.push_back(nullptr);
    }

    Ort::ThrowOnError(Ort::GetApi().Run(*_session, [runOptions CXXAPIOrtRunOptions],
                                        inputNames.data(), inputValues.data(), inputNames.size(),
                                        outputNames.data(), outputNames.size(), outputValues.data()));

    NSMutableDictionary<NSString*, ORTValue*>* outputs = [[NSMutableDictionary alloc] init];
    for (NSUInteger i = 0; i < outputNameArray.count; ++i) {
      ORTValue* outputValue = [[ORTValue alloc] initWithCAPIOrtValue:outputValues[i] externalTensorData:nil error:error];
      if (!outputValue) {
        // Clean up the remaining C API OrtValues that have not been wrapped by an ORTValue yet.
        // NOTE(review): this includes index i, which assumes -[ORTValue initWithCAPIOrtValue:...]
        // does not take ownership of the OrtValue when it fails; confirm against ort_value.mm.
        for (NSUInteger j = i; j < outputNameArray.count; ++j) {
          Ort::GetApi().ReleaseValue(outputValues[j]);
        }
        return nil;
      }

      outputs[outputNameArray[i]] = outputValue;
    }

    return outputs;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Returns the model's input names, or nil on failure.
- (nullable NSArray<NSString*>*)inputNamesWithError:(NSError**)error {
  return [self namesWithType:NamedValueType::Input error:error];
}

/// Returns the model's overridable initializer names, or nil on failure.
- (nullable NSArray<NSString*>*)overridableInitializerNamesWithError:(NSError**)error {
  return [self namesWithType:NamedValueType::OverridableInitializer error:error];
}

/// Returns the model's output names, or nil on failure.
- (nullable NSArray<NSString*>*)outputNamesWithError:(NSError**)error {
  return [self namesWithType:NamedValueType::Output error:error];
}

#pragma mark - Private

/// Collects the session names of the requested kind, in index order.
///
/// @param namedValueType Which group of names to query.
/// @param error Receives error information on failure.
/// @return The names, or nil on failure.
- (nullable NSArray<NSString*>*)namesWithType:(NamedValueType)namedValueType
                                        error:(NSError**)error {
  try {
    auto getCount = [&session = *_session, namedValueType]() {
      if (namedValueType == NamedValueType::Input) {
        return session.GetInputCount();
      } else if (namedValueType == NamedValueType::OverridableInitializer) {
        return session.GetOverridableInitializerCount();
      } else {
        return session.GetOutputCount();
      }
    };

    auto getName = [&session = *_session, namedValueType](size_t i, OrtAllocator* allocator) {
      if (namedValueType == NamedValueType::Input) {
        return session.GetInputNameAllocated(i, allocator);
      } else if (namedValueType == NamedValueType::OverridableInitializer) {
        return session.GetOverridableInitializerNameAllocated(i, allocator);
      } else {
        return session.GetOutputNameAllocated(i, allocator);
      }
    };

    const size_t nameCount = getCount();

    Ort::AllocatorWithDefaultOptions allocator;
    NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:nameCount];

    for (size_t i = 0; i < nameCount; ++i) {
      auto name = getName(i, allocator);
      NSString* nameNsstr = [NSString stringWithUTF8String:name.get()];
      if (nameNsstr == nil) {
        // stringWithUTF8String: returns nil for bytes it cannot decode. NSAssert would be disabled
        // when NS_BLOCK_ASSERTIONS is defined (typical for release builds), and adding nil to the
        // array would then raise an Objective-C exception. Throw an Ort::Exception instead so the
        // failure is reported through the catch macro below.
        throw Ort::Exception("Failed to convert name to NSString. It may not be valid UTF-8.", ORT_FAIL);
      }
      [result addObject:nameNsstr];
    }

    return result;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

@end
/// Objective-C wrapper around Ort::SessionOptions.
///
/// Each setter forwards to the wrapped C++ object inside a `try` block. The block is followed by
/// ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL, which receives the caller's `error` pointer.
/// Each setter returns YES on success and NO on failure.
@implementation ORTSessionOptions {
  // Assigned in -initWithError:; the initializer returns nil if construction fails.
  std::optional<Ort::SessionOptions> _sessionOptions;
}

#pragma mark - Public

/// Creates default session options.
///
/// @param error Receives error information on failure.
/// @return The initialized options, or nil on failure.
- (nullable instancetype)initWithError:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    _sessionOptions = Ort::SessionOptions{};
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Appends the execution provider named `providerName`, configured with `providerOptions`.
///
/// @param providerName Provider name, passed to ORT as UTF-8.
/// @param providerOptions Provider-specific key/value options, passed to ORT as UTF-8 strings.
/// @param error Receives error information on failure.
- (BOOL)appendExecutionProvider:(NSString*)providerName
                providerOptions:(NSDictionary<NSString*, NSString*>*)providerOptions
                          error:(NSError**)error {
  try {
    std::unordered_map<std::string, std::string> options;
    options.reserve(providerOptions.count);
    // Fast enumeration over a dictionary yields its keys.
    for (NSString* key in providerOptions) {
      options.emplace(key.UTF8String, providerOptions[key].UTF8String);
    }

    _sessionOptions->AppendExecutionProvider(providerName.UTF8String, options);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Sets the number of threads used to parallelize execution within nodes.
- (BOOL)setIntraOpNumThreads:(int)intraOpNumThreads
                       error:(NSError**)error {
  try {
    _sessionOptions->SetIntraOpNumThreads(intraOpNumThreads);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Sets the graph optimization level, after converting it to the C API enum.
- (BOOL)setGraphOptimizationLevel:(ORTGraphOptimizationLevel)graphOptimizationLevel
                            error:(NSError**)error {
  try {
    _sessionOptions->SetGraphOptimizationLevel(
        PublicToCAPIGraphOptimizationLevel(graphOptimizationLevel));
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Sets the file path where ORT saves the optimized model.
- (BOOL)setOptimizedModelFilePath:(NSString*)optimizedModelFilePath
                            error:(NSError**)error {
  try {
    _sessionOptions->SetOptimizedModelFilePath(optimizedModelFilePath.UTF8String);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Sets the session log identifier.
- (BOOL)setLogID:(NSString*)logID
           error:(NSError**)error {
  try {
    _sessionOptions->SetLogId(logID.UTF8String);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Sets the session log severity level, after converting it to the C API enum.
- (BOOL)setLogSeverityLevel:(ORTLoggingLevel)loggingLevel
                      error:(NSError**)error {
  try {
    _sessionOptions->SetLogSeverityLevel(PublicToCAPILoggingLevel(loggingLevel));
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Adds a session configuration entry.
- (BOOL)addConfigEntryWithKey:(NSString*)key
                        value:(NSString*)value
                        error:(NSError**)error {
  try {
    _sessionOptions->AddConfigEntry(key.UTF8String, value.UTF8String);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Registers custom ops by calling the registration function named `registrationFuncName`.
- (BOOL)registerCustomOpsUsingFunction:(NSString*)registrationFuncName
                                 error:(NSError**)error {
  try {
    _sessionOptions->RegisterCustomOpsUsingFunction(registrationFuncName.UTF8String);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

#pragma mark - Internal

/// Returns the wrapped C++ session options.
- (Ort::SessionOptions&)CXXAPIOrtSessionOptions {
  return *_sessionOptions;
}

@end
/// Objective-C wrapper around Ort::RunOptions, the per-call options used by ORTSession's run methods.
///
/// Setters forward to the wrapped C++ object. Exceptions thrown in their `try` blocks are handled
/// by ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL, which receives the caller's `error` pointer.
@implementation ORTRunOptions {
  // Constructed in -initWithError:; if construction throws, the initializer returns nil.
  std::optional<Ort::RunOptions> _runOptions;
}

#pragma mark - Public

/// Creates default run options.
///
/// @param error Receives error information on failure.
/// @return The initialized options, or nil on failure.
- (nullable instancetype)initWithError:(NSError**)error {
  self = [super init];
  if (self == nil) {
    return nil;
  }

  try {
    _runOptions.emplace();
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Sets the tag attached to log output for this run.
- (BOOL)setLogTag:(NSString*)logTag
            error:(NSError**)error {
  try {
    Ort::RunOptions& options = *_runOptions;
    options.SetRunTag(logTag.UTF8String);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Sets the log severity level for this run, after converting it to the C API enum.
- (BOOL)setLogSeverityLevel:(ORTLoggingLevel)loggingLevel
                      error:(NSError**)error {
  try {
    Ort::RunOptions& options = *_runOptions;
    options.SetRunLogSeverityLevel(PublicToCAPILoggingLevel(loggingLevel));
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Adds a run configuration entry.
- (BOOL)addConfigEntryWithKey:(NSString*)key
                        value:(NSString*)value
                        error:(NSError**)error {
  try {
    Ort::RunOptions& options = *_runOptions;
    options.AddConfigEntry(key.UTF8String, value.UTF8String);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

#pragma mark - Internal

/// Returns the wrapped C++ run options.
- (Ort::RunOptions&)CXXAPIOrtRunOptions {
  return *_runOptions;
}

@end
NS_ASSUME_NONNULL_END