// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#import "ort_training_session_internal.h"

#import <optional>
#import <string>
#import <vector>

#import "cxx_api.h"
#import "cxx_utils.h"
#import "error_utils.h"
#import "ort_checkpoint_internal.h"
#import "ort_session_internal.h"
#import "ort_enums_internal.h"
#import "ort_env_internal.h"
#import "ort_value_internal.h"

NS_ASSUME_NONNULL_BEGIN

@implementation ORTTrainingSession {
  ORTEnv* _env;  // keep a strong reference so the ORTEnv doesn't get destroyed before this does
  // Strong reference to the checkpoint the session was created from.
  // NOTE(review): presumably held so the checkpoint state outlives the C++ session -- confirm.
  ORTCheckpoint* _checkpoint;
  // Engaged only after a successful -initWithEnv:...; the accessor below dereferences it unchecked.
  std::optional<Ort::TrainingSession> _session;
}

#pragma mark - Internal

/// Returns the wrapped C++ training session.
/// Only valid on an instance whose initializer succeeded, because `_session` is dereferenced
/// without checking that it is engaged.
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
  return *_session;
}

#pragma mark - Initialization

/// Creates a training session from a training model and, optionally, eval and optimizer models.
///
/// @param env The environment. A strong reference is kept for the lifetime of this object.
/// @param sessionOptions Session options. When nil, a default ORTSessionOptions is created.
/// @param checkpoint The checkpoint handed to the session. A strong reference is kept.
/// @param trainModelPath Path to the training model.
/// @param evalModelPath Optional path to the eval model, converted with
///        utils::toStdOptionalString (presumably nil -> empty optional).
/// @param optimizerModelPath Optional path to the optimizer model, converted the same way.
/// @param error Receives the error on failure.
/// @return The initialized session, or nil on failure.
- (nullable instancetype)initWithEnv:(ORTEnv*)env
                      sessionOptions:(nullable ORTSessionOptions*)sessionOptions
                          checkpoint:(ORTCheckpoint*)checkpoint
                      trainModelPath:(NSString*)trainModelPath
                       evalModelPath:(nullable NSString*)evalModelPath
                  optimizerModelPath:(nullable NSString*)optimizerModelPath
                               error:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    if (!sessionOptions) {
      sessionOptions = [[ORTSessionOptions alloc] initWithError:error];
      if (!sessionOptions) {
        return nil;
      }
    }

    std::optional<std::string> evalPath = utils::toStdOptionalString(evalModelPath);
    std::optional<std::string> optimizerPath = utils::toStdOptionalString(optimizerModelPath);

    _env = env;
    _checkpoint = checkpoint;
    _session = Ort::TrainingSession{
        [env CXXAPIOrtEnv],
        [sessionOptions CXXAPIOrtSessionOptions],
        [checkpoint CXXAPIOrtCheckpoint],
        trainModelPath.UTF8String,
        evalPath,
        optimizerPath};
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

#pragma mark - Training / evaluation

/// Runs one training step on `inputs` and returns the training model's outputs.
///
/// @param inputs Input values, in the order expected by the training model.
/// @param error Receives the error on failure.
/// @return One ORTValue per training-model output, or nil on failure.
- (nullable NSArray<ORTValue*>*)trainStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                   error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];

    // C API pointers to the OrtValues wrapped by `inputs`; they are not released here.
    std::vector<const OrtValue*> inputValues = utils::getWrappedCAPIOrtValues(inputs);

    // Size the output buffer to the training model's output count; TrainStep fills it in.
    size_t outputCount = 0;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(session, &outputCount));
    std::vector<OrtValue*> outputValues(outputCount, nullptr);

    Ort::RunOptions runOptions;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainStep(
        session,
        runOptions,
        inputValues.size(),
        inputValues.data(),
        outputValues.size(),
        outputValues.data()));

    // The produced OrtValues are not yet owned by any wrapper; wrapUnownedCAPIOrtValues presumably
    // transfers ownership to the returned ORTValue objects.
    return utils::wrapUnownedCAPIOrtValues(outputValues, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Runs one evaluation step on `inputs` and returns the eval model's outputs.
///
/// @param inputs Input values, in the order expected by the eval model.
/// @param error Receives the error on failure.
/// @return One ORTValue per eval-model output, or nil on failure.
- (nullable NSArray<ORTValue*>*)evalStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                  error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];

    // C API pointers to the OrtValues wrapped by `inputs`; they are not released here.
    std::vector<const OrtValue*> inputValues = utils::getWrappedCAPIOrtValues(inputs);

    // Size the output buffer to the eval model's output count; EvalStep fills it in.
    size_t outputCount = 0;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetEvalModelOutputCount(session, &outputCount));
    std::vector<OrtValue*> outputValues(outputCount, nullptr);

    Ort::RunOptions runOptions;
    Ort::ThrowOnError(Ort::GetTrainingApi().EvalStep(
        session,
        runOptions,
        inputValues.size(),
        inputValues.data(),
        outputValues.size(),
        outputValues.data()));

    // The produced OrtValues are not yet owned by any wrapper; wrapUnownedCAPIOrtValues presumably
    // transfers ownership to the returned ORTValue objects.
    return utils::wrapUnownedCAPIOrtValues(outputValues, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Forwards to Ort::TrainingSession::LazyResetGrad.
/// @return YES on success, NO (with `error` set) on failure.
- (BOOL)lazyResetGradWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].LazyResetGrad();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Forwards to Ort::TrainingSession::OptimizerStep.
/// @return YES on success, NO (with `error` set) on failure.
- (BOOL)optimizerStepWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].OptimizerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

#pragma mark - Model input/output names

/// Returns the training model's input names, or nil (with `error` set) on failure.
- (nullable NSArray<NSString*>*)getTrainInputNamesWithError:(NSError**)error {
  try {
    // `true` selects the training model; the eval variants below pass `false`.
    std::vector<std::string> inputNames = [self CXXAPIOrtTrainingSession].InputNames(true);
    return utils::toNSStringNSArray(inputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Returns the training model's output names, or nil (with `error` set) on failure.
- (nullable NSArray<NSString*>*)getTrainOutputNamesWithError:(NSError**)error {
  try {
    std::vector<std::string> outputNames = [self CXXAPIOrtTrainingSession].OutputNames(true);
    return utils::toNSStringNSArray(outputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Returns the eval model's input names, or nil (with `error` set) on failure.
- (nullable NSArray<NSString*>*)getEvalInputNamesWithError:(NSError**)error {
  try {
    std::vector<std::string> inputNames = [self CXXAPIOrtTrainingSession].InputNames(false);
    return utils::toNSStringNSArray(inputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Returns the eval model's output names, or nil (with `error` set) on failure.
- (nullable NSArray<NSString*>*)getEvalOutputNamesWithError:(NSError**)error {
  try {
    std::vector<std::string> outputNames = [self CXXAPIOrtTrainingSession].OutputNames(false);
    return utils::toNSStringNSArray(outputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

#pragma mark - Learning rate

/// Forwards to Ort::TrainingSession::RegisterLinearLRScheduler with the given arguments.
/// @return YES on success, NO (with `error` set) on failure.
- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].RegisterLinearLRScheduler(warmupStepCount, totalStepCount, initialLr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Forwards to Ort::TrainingSession::SchedulerStep.
/// @return YES on success, NO (with `error` set) on failure.
- (BOOL)schedulerStepWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].SchedulerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Returns the session's current learning rate.
/// On failure, `error` is set and 0.0f is returned. Because 0.0f is indistinguishable from a real
/// value, callers that need to detect failure must inspect `error`.
- (float)getLearningRateWithError:(NSError**)error {
  try {
    return [self CXXAPIOrtTrainingSession].GetLearningRate();
  }
  ORT_OBJC_API_IMPL_CATCH(error, 0.0f)
}

/// Forwards `lr` to Ort::TrainingSession::SetLearningRate.
/// @return YES on success, NO (with `error` set) on failure.
- (BOOL)setLearningRate:(float)lr error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].SetLearningRate(lr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

#pragma mark - Parameter buffers and export

/// Forwards the value wrapped by `buffer` to Ort::TrainingSession::FromBuffer.
/// @return YES on success, NO (with `error` set) on failure.
- (BOOL)fromBufferWithValue:(ORTValue*)buffer error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].FromBuffer([buffer CXXAPIOrtValue]);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Calls Ort::TrainingSession::ToBuffer(onlyTrainable) and wraps the result in a new ORTValue.
/// The resulting Ort::Value is moved into the ORTValue with no external tensor data.
/// @return The wrapped value, or nil (with `error` set) on failure.
- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable error:(NSError**)error {
  try {
    Ort::Value val = [self CXXAPIOrtTrainingSession].ToBuffer(onlyTrainable);
    return [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(val)
                                 externalTensorData:nil
                                              error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Forwards to Ort::TrainingSession::ExportModelForInferencing, writing the exported model to
/// `inferenceModelPath` with the given graph output names.
/// @return YES on success, NO (with `error` set) on failure.
- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath
                             graphOutputNames:(NSArray<NSString*>*)graphOutputNames
                                        error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].ExportModelForInferencing(utils::toStdString(inferenceModelPath),
                                                               utils::toStdStringVector(graphOutputNames));
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

@end

/// Forwards `seed` to Ort::SetSeed.
/// NOTE(review): unlike the methods above, there is no try/catch here, so any C++ exception thrown
/// by Ort::SetSeed would propagate out of this C function to its (possibly Objective-C/Swift)
/// caller -- confirm whether Ort::SetSeed can throw.
void ORTSetSeed(int64_t seed) {
  Ort::SetSeed(seed);
}

NS_ASSUME_NONNULL_END