Objective-C API updates (#18738)

- Add ORTSession and ORTTrainingSession strong references to ORTEnv.
- Make ORTTrainingSession session options parameter optional.
This commit is contained in:
Edward Chen 2023-12-07 16:47:46 -08:00 committed by GitHub
parent bf33919afb
commit 7ed48a299a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 4 deletions

View file

@ -24,6 +24,9 @@ NSString* _Nullable ORTVersion(void);
/**
* The ORT environment.
* It maintains shared state including the default logger.
*
* @note One ORTEnv should be created before and destroyed after other ORT API usage.
*/
@interface ORTEnv : NSObject

View file

@ -39,7 +39,7 @@ NS_ASSUME_NONNULL_BEGIN
* session which will be moved to the device specified in the session option if needed.
*
* @param env The `ORTEnv` instance to use for the training session.
* @param sessionOptions The `ORTSessionOptions` to use for the training session.
* @param sessionOptions The optional `ORTSessionOptions` to use for the training session.
* @param checkpoint Training states that are used as a starting point for training.
* @param trainModelPath The path to the training onnx model.
* @param evalModelPath The path to the evaluation onnx model.
@ -52,7 +52,7 @@ NS_ASSUME_NONNULL_BEGIN
* keeps a strong (owning) pointer to the checkpoint state.
*/
- (nullable instancetype)initWithEnv:(ORTEnv*)env
sessionOptions:(ORTSessionOptions*)sessionOptions
sessionOptions:(nullable ORTSessionOptions*)sessionOptions
checkpoint:(ORTCheckpoint*)checkpoint
trainModelPath:(NSString*)trainModelPath
evalModelPath:(nullable NSString*)evalModelPath

View file

@ -23,6 +23,7 @@ enum class NamedValueType {
NS_ASSUME_NONNULL_BEGIN
@implementation ORTSession {
ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does
std::optional<Ort::Session> _session;
}
@ -44,6 +45,7 @@ NS_ASSUME_NONNULL_BEGIN
}
}
_env = env;
_session = Ort::Session{[env CXXAPIOrtEnv],
path.UTF8String,
[sessionOptions CXXAPIOrtSessionOptions]};

View file

@ -19,8 +19,9 @@
NS_ASSUME_NONNULL_BEGIN
@implementation ORTTrainingSession {
std::optional<Ort::TrainingSession> _session;
ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does
ORTCheckpoint* _checkpoint;
std::optional<Ort::TrainingSession> _session;
}
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
@ -28,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN
}
- (nullable instancetype)initWithEnv:(ORTEnv*)env
sessionOptions:(ORTSessionOptions*)sessionOptions
sessionOptions:(nullable ORTSessionOptions*)sessionOptions
checkpoint:(ORTCheckpoint*)checkpoint
trainModelPath:(NSString*)trainModelPath
evalModelPath:(nullable NSString*)evalModelPath
@ -39,9 +40,17 @@ NS_ASSUME_NONNULL_BEGIN
}
try {
if (!sessionOptions) {
sessionOptions = [[ORTSessionOptions alloc] initWithError:error];
if (!sessionOptions) {
return nil;
}
}
std::optional<std::string> evalPath = utils::toStdOptionalString(evalModelPath);
std::optional<std::string> optimizerPath = utils::toStdOptionalString(optimizerModelPath);
_env = env;
_checkpoint = checkpoint;
_session = Ort::TrainingSession{
[env CXXAPIOrtEnv],
@ -50,6 +59,7 @@ NS_ASSUME_NONNULL_BEGIN
trainModelPath.UTF8String,
evalPath,
optimizerPath};
return self;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)

View file

@ -295,6 +295,32 @@ static OrtStatus* _Nullable DummyRegisterCustomOpsFn(OrtSessionOptions* /*sessio
XCTAssertTrue([stringData isEqualToArray:outputStringData]);
}
- (void)testKeepORTEnvReference {
ORTEnv* __weak envWeak = _ortEnv;
// Remove sole strong reference to the ORTEnv created in setUp.
_ortEnv = nil;
// There should be no more strong references to it.
XCTAssertNil(envWeak);
// Create a new ORTEnv.
NSError* err = nil;
ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
error:&err];
ORTAssertNullableResultSuccessful(env, err);
ORTSession* session = [[ORTSession alloc] initWithEnv:env
modelPath:[ORTSessionTest getAddModelPath]
sessionOptions:[ORTSessionTest makeSessionOptions]
error:&err];
ORTAssertNullableResultSuccessful(session, err);
envWeak = env;
// Remove strong reference to the ORTEnv passed to the ORTSession initializer.
env = nil;
// ORTSession should keep a strong reference to it.
XCTAssertNotNil(envWeak);
}
@end
NS_ASSUME_NONNULL_END