mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Objective-C API updates (#18738)
- Add ORTSession and ORTTrainingSession strong references to ORTEnv. - Make ORTTrainingSession session options parameter optional.
This commit is contained in:
parent
bf33919afb
commit
7ed48a299a
5 changed files with 45 additions and 4 deletions
|
|
@ -24,6 +24,9 @@ NSString* _Nullable ORTVersion(void);
|
|||
|
||||
/**
|
||||
* The ORT environment.
|
||||
* It maintains shared state including the default logger.
|
||||
*
|
||||
* @note One ORTEnv should be created before and destroyed after other ORT API usage.
|
||||
*/
|
||||
@interface ORTEnv : NSObject
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
* session which will be moved to the device specified in the session option if needed.
|
||||
*
|
||||
* @param env The `ORTEnv` instance to use for the training session.
|
||||
* @param sessionOptions The `ORTSessionOptions` to use for the training session.
|
||||
* @param sessionOptions The optional `ORTSessionOptions` to use for the training session.
|
||||
* @param checkpoint Training states that are used as a starting point for training.
|
||||
* @param trainModelPath The path to the training onnx model.
|
||||
* @param evalModelPath The path to the evaluation onnx model.
|
||||
|
|
@ -52,7 +52,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
* keeps a strong (owning) pointer to the checkpoint state.
|
||||
*/
|
||||
- (nullable instancetype)initWithEnv:(ORTEnv*)env
|
||||
sessionOptions:(ORTSessionOptions*)sessionOptions
|
||||
sessionOptions:(nullable ORTSessionOptions*)sessionOptions
|
||||
checkpoint:(ORTCheckpoint*)checkpoint
|
||||
trainModelPath:(NSString*)trainModelPath
|
||||
evalModelPath:(nullable NSString*)evalModelPath
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ enum class NamedValueType {
|
|||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@implementation ORTSession {
|
||||
ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does
|
||||
std::optional<Ort::Session> _session;
|
||||
}
|
||||
|
||||
|
|
@ -44,6 +45,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
}
|
||||
}
|
||||
|
||||
_env = env;
|
||||
_session = Ort::Session{[env CXXAPIOrtEnv],
|
||||
path.UTF8String,
|
||||
[sessionOptions CXXAPIOrtSessionOptions]};
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@
|
|||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@implementation ORTTrainingSession {
|
||||
std::optional<Ort::TrainingSession> _session;
|
||||
ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does
|
||||
ORTCheckpoint* _checkpoint;
|
||||
std::optional<Ort::TrainingSession> _session;
|
||||
}
|
||||
|
||||
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
|
||||
|
|
@ -28,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
}
|
||||
|
||||
- (nullable instancetype)initWithEnv:(ORTEnv*)env
|
||||
sessionOptions:(ORTSessionOptions*)sessionOptions
|
||||
sessionOptions:(nullable ORTSessionOptions*)sessionOptions
|
||||
checkpoint:(ORTCheckpoint*)checkpoint
|
||||
trainModelPath:(NSString*)trainModelPath
|
||||
evalModelPath:(nullable NSString*)evalModelPath
|
||||
|
|
@ -39,9 +40,17 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
}
|
||||
|
||||
try {
|
||||
if (!sessionOptions) {
|
||||
sessionOptions = [[ORTSessionOptions alloc] initWithError:error];
|
||||
if (!sessionOptions) {
|
||||
return nil;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<std::string> evalPath = utils::toStdOptionalString(evalModelPath);
|
||||
std::optional<std::string> optimizerPath = utils::toStdOptionalString(optimizerModelPath);
|
||||
|
||||
_env = env;
|
||||
_checkpoint = checkpoint;
|
||||
_session = Ort::TrainingSession{
|
||||
[env CXXAPIOrtEnv],
|
||||
|
|
@ -50,6 +59,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
trainModelPath.UTF8String,
|
||||
evalPath,
|
||||
optimizerPath};
|
||||
|
||||
return self;
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
|
||||
|
|
|
|||
|
|
@ -295,6 +295,32 @@ static OrtStatus* _Nullable DummyRegisterCustomOpsFn(OrtSessionOptions* /*sessio
|
|||
XCTAssertTrue([stringData isEqualToArray:outputStringData]);
|
||||
}
|
||||
|
||||
- (void)testKeepORTEnvReference {
|
||||
ORTEnv* __weak envWeak = _ortEnv;
|
||||
// Remove sole strong reference to the ORTEnv created in setUp.
|
||||
_ortEnv = nil;
|
||||
// There should be no more strong references to it.
|
||||
XCTAssertNil(envWeak);
|
||||
|
||||
// Create a new ORTEnv.
|
||||
NSError* err = nil;
|
||||
ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
|
||||
error:&err];
|
||||
ORTAssertNullableResultSuccessful(env, err);
|
||||
|
||||
ORTSession* session = [[ORTSession alloc] initWithEnv:env
|
||||
modelPath:[ORTSessionTest getAddModelPath]
|
||||
sessionOptions:[ORTSessionTest makeSessionOptions]
|
||||
error:&err];
|
||||
ORTAssertNullableResultSuccessful(session, err);
|
||||
|
||||
envWeak = env;
|
||||
// Remove strong reference to the ORTEnv passed to the ORTSession initializer.
|
||||
env = nil;
|
||||
// ORTSession should keep a strong reference to it.
|
||||
XCTAssertNotNil(envWeak);
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
|
|
|||
Loading…
Reference in a new issue