diff --git a/objectivec/include/ort_env.h b/objectivec/include/ort_env.h index 8456b57bfa..67db76668b 100644 --- a/objectivec/include/ort_env.h +++ b/objectivec/include/ort_env.h @@ -24,6 +24,9 @@ NSString* _Nullable ORTVersion(void); /** * The ORT environment. + * It maintains shared state including the default logger. + * + * @note One ORTEnv should be created before and destroyed after other ORT API usage. */ @interface ORTEnv : NSObject diff --git a/objectivec/include/ort_training_session.h b/objectivec/include/ort_training_session.h index 15c0137817..2ad4fed93c 100644 --- a/objectivec/include/ort_training_session.h +++ b/objectivec/include/ort_training_session.h @@ -39,7 +39,7 @@ NS_ASSUME_NONNULL_BEGIN * session which will be moved to the device specified in the session option if needed. * * @param env The `ORTEnv` instance to use for the training session. - * @param sessionOptions The `ORTSessionOptions` to use for the training session. + * @param sessionOptions The optional `ORTSessionOptions` to use for the training session. * @param checkpoint Training states that are used as a starting point for training. * @param trainModelPath The path to the training onnx model. * @param evalModelPath The path to the evaluation onnx model. @@ -52,7 +52,7 @@ NS_ASSUME_NONNULL_BEGIN * keeps a strong (owning) pointer to the checkpoint state. */ - (nullable instancetype)initWithEnv:(ORTEnv*)env - sessionOptions:(ORTSessionOptions*)sessionOptions + sessionOptions:(nullable ORTSessionOptions*)sessionOptions checkpoint:(ORTCheckpoint*)checkpoint trainModelPath:(NSString*)trainModelPath evalModelPath:(nullable NSString*)evalModelPath diff --git a/objectivec/ort_session.mm b/objectivec/ort_session.mm index d27c3e2cef..87288bd1e9 100644 --- a/objectivec/ort_session.mm +++ b/objectivec/ort_session.mm @@ -23,6 +23,7 @@ enum class NamedValueType { NS_ASSUME_NONNULL_BEGIN @implementation ORTSession { + ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does std::optional _session; } @@ -44,6 +45,7 @@ NS_ASSUME_NONNULL_BEGIN } } + _env = env; _session = Ort::Session{[env CXXAPIOrtEnv], path.UTF8String, [sessionOptions CXXAPIOrtSessionOptions]}; diff --git a/objectivec/ort_training_session.mm b/objectivec/ort_training_session.mm index 285151b412..5387bfda6d 100644 --- a/objectivec/ort_training_session.mm +++ b/objectivec/ort_training_session.mm @@ -19,8 +19,9 @@ NS_ASSUME_NONNULL_BEGIN @implementation ORTTrainingSession { - std::optional _session; + ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does ORTCheckpoint* _checkpoint; + std::optional _session; } - (Ort::TrainingSession&)CXXAPIOrtTrainingSession { @@ -28,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN } - (nullable instancetype)initWithEnv:(ORTEnv*)env - sessionOptions:(ORTSessionOptions*)sessionOptions + sessionOptions:(nullable ORTSessionOptions*)sessionOptions checkpoint:(ORTCheckpoint*)checkpoint trainModelPath:(NSString*)trainModelPath evalModelPath:(nullable NSString*)evalModelPath @@ -39,9 +40,17 @@ NS_ASSUME_NONNULL_BEGIN } try { + if (!sessionOptions) { + sessionOptions = [[ORTSessionOptions alloc] initWithError:error]; + if (!sessionOptions) { + return nil; + } + } + std::optional evalPath = utils::toStdOptionalString(evalModelPath); std::optional optimizerPath = utils::toStdOptionalString(optimizerModelPath); + _env = env; _checkpoint = checkpoint; _session = Ort::TrainingSession{ [env CXXAPIOrtEnv], @@ -50,6 +59,7 @@ NS_ASSUME_NONNULL_BEGIN trainModelPath.UTF8String, evalPath, optimizerPath}; + return self; } ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) diff --git a/objectivec/test/ort_session_test.mm b/objectivec/test/ort_session_test.mm index f00f5db2f9..508289f7bc 100644 --- a/objectivec/test/ort_session_test.mm +++ b/objectivec/test/ort_session_test.mm @@ -295,6 +295,32 @@ static OrtStatus* _Nullable DummyRegisterCustomOpsFn(OrtSessionOptions* /*sessio XCTAssertTrue([stringData isEqualToArray:outputStringData]); } +- (void)testKeepORTEnvReference { + ORTEnv* __weak envWeak = _ortEnv; + // Remove sole strong reference to the ORTEnv created in setUp. + _ortEnv = nil; + // There should be no more strong references to it. + XCTAssertNil(envWeak); + + // Create a new ORTEnv. + NSError* err = nil; + ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning + error:&err]; + ORTAssertNullableResultSuccessful(env, err); + + ORTSession* session = [[ORTSession alloc] initWithEnv:env + modelPath:[ORTSessionTest getAddModelPath] + sessionOptions:[ORTSessionTest makeSessionOptions] + error:&err]; + ORTAssertNullableResultSuccessful(session, err); + + envWeak = env; + // Remove strong reference to the ORTEnv passed to the ORTSession initializer. + env = nil; + // ORTSession should keep a strong reference to it. + XCTAssertNotNil(envWeak); +} + @end NS_ASSUME_NONNULL_END