onnxruntime/objectivec/ort_training_session.mm
Vrajang Parikh fd8ad9b950
Enable iOS packaging for training (#16525)
### Description
Enable support for building iOS packages/CocoaPods with the training API.

- Add a `Training` package variant and config files to the current iOS
packaging utilities to enable the creation of training packages

### Motivation and Context
This PR introduces a new `Training` variant in the
`build_and_assemble_ios_pods.py` script, which allows creating iOS pods
with the training API enabled.

Sample command to build the training pods:

```
python3 tools/ci_build/github/apple/build_and_assemble_ios_pods.py --variant Training \
  --build-settings-file tools/ci_build/github/apple/default_full_ios_training_framework_build_settings.json \
  -b=--path_to_protoc_exe=<path/to/protoc>
``` 

Note: the build settings file must include `--enable_training` as a build
parameter.


Simply adding training packaging to the Azure packaging pipeline increases
its duration by 70 minutes. To address this, pod creation needs to be
parallelized. To avoid further straining the pipeline, the pipeline changes
for training packaging will be made in a separate PR that also optimizes the
packaging pipeline.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
2023-07-05 13:27:59 -07:00

224 lines
7.1 KiB
Text

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import "ort_training_session_internal.h"
#import <vector>
#import <optional>
#import <string>
#import "cxx_api.h"
#import "cxx_utils.h"
#import "error_utils.h"
#import "ort_checkpoint_internal.h"
#import "ort_session_internal.h"
#import "ort_enums_internal.h"
#import "ort_env_internal.h"
#import "ort_value_internal.h"
NS_ASSUME_NONNULL_BEGIN
@implementation ORTTrainingSession {
  // The wrapped C++ training session; engaged by the initializer on success.
  std::optional<Ort::TrainingSession> _session;
  // Strong reference to the checkpoint passed to the initializer. It is not read
  // anywhere else in this file -- presumably it is held so the checkpoint stays
  // alive for as long as the native training session uses it (TODO confirm).
  ORTCheckpoint* _checkpoint;
}

#pragma mark - Internal

/// Returns a reference to the wrapped C++ training session.
/// Only valid on an instance whose initializer succeeded; the initializer
/// returns nil when session construction fails.
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
  return *_session;
}

#pragma mark - Initialization

/// Creates a training session from an environment, session options, a checkpoint, and model paths.
/// @param env The environment used to create the session.
/// @param sessionOptions The options used to create the session.
/// @param checkpoint The checkpoint used by the session; this object keeps a strong reference to it.
/// @param trainModelPath Path to the training model.
/// @param evalModelPath Optional path to the eval model.
/// @param optimizerModelPath Optional path to the optimizer model.
/// @param error Populated if session creation fails.
/// @return The initialized session, or nil on failure.
- (nullable instancetype)initWithEnv:(ORTEnv*)env
                      sessionOptions:(ORTSessionOptions*)sessionOptions
                          checkpoint:(ORTCheckpoint*)checkpoint
                      trainModelPath:(NSString*)trainModelPath
                       evalModelPath:(nullable NSString*)evalModelPath
                  optimizerModelPath:(nullable NSString*)optimizerModelPath
                               error:(NSError**)error {
  if ((self = [super init]) == nil) {
    return nil;
  }

  try {
    // The eval and optimizer models are optional; a nil path is presumably
    // converted to std::nullopt by toStdOptionalString.
    std::optional<std::string> evalPath = utils::toStdOptionalString(evalModelPath);
    std::optional<std::string> optimizerPath = utils::toStdOptionalString(optimizerModelPath);

    _checkpoint = checkpoint;
    _session = Ort::TrainingSession{
        [env CXXAPIOrtEnv],
        [sessionOptions CXXAPIOrtSessionOptions],
        [checkpoint CXXAPIOrtCheckpoint],
        trainModelPath.UTF8String,
        evalPath,
        optimizerPath};
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

#pragma mark - Training and evaluation steps

/// Runs one step of the training model with the given inputs.
/// @param inputs The input values for the training model.
/// @param error Populated on failure.
/// @return The training model's outputs, or nil on failure.
- (nullable NSArray<ORTValue*>*)trainStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                   error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];

    // Gather the C OrtValue pointers wrapped by the input ORTValues.
    std::vector<const OrtValue*> inputValues = utils::getWrappedCAPIOrtValues(inputs);

    // Size the output buffer to the number of training model outputs.
    size_t outputCount = 0;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(session, &outputCount));
    std::vector<OrtValue*> outputValues(outputCount, nullptr);

    Ort::RunOptions runOptions;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainStep(
        session,
        runOptions,
        inputValues.size(),
        inputValues.data(),
        outputValues.size(),
        outputValues.data()));

    // TrainStep fills outputValues with C OrtValue pointers; wrap them as ORTValue objects.
    return utils::wrapUnownedCAPIOrtValues(outputValues, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Runs one step of the eval model with the given inputs.
/// @param inputs The input values for the eval model.
/// @param error Populated on failure.
/// @return The eval model's outputs, or nil on failure.
- (nullable NSArray<ORTValue*>*)evalStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                  error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];

    // Gather the C OrtValue pointers wrapped by the input ORTValues.
    std::vector<const OrtValue*> inputValues = utils::getWrappedCAPIOrtValues(inputs);

    // Size the output buffer to the number of eval model outputs.
    size_t outputCount = 0;
    Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetEvalModelOutputCount(session, &outputCount));
    std::vector<OrtValue*> outputValues(outputCount, nullptr);

    Ort::RunOptions runOptions;
    Ort::ThrowOnError(Ort::GetTrainingApi().EvalStep(
        session,
        runOptions,
        inputValues.size(),
        inputValues.data(),
        outputValues.size(),
        outputValues.data()));

    // EvalStep fills outputValues with C OrtValue pointers; wrap them as ORTValue objects.
    return utils::wrapUnownedCAPIOrtValues(outputValues, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

#pragma mark - Gradients and optimizer

/// Forwards to TrainingSession::LazyResetGrad.
/// @return YES on success, NO (with error populated) on failure.
- (BOOL)lazyResetGradWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].LazyResetGrad();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Forwards to TrainingSession::OptimizerStep.
/// @return YES on success, NO (with error populated) on failure.
- (BOOL)optimizerStepWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].OptimizerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

#pragma mark - Model input/output names

/// Returns the training model's input names (InputNames(true)), or nil on failure.
- (nullable NSArray<NSString*>*)getTrainInputNamesWithError:(NSError**)error {
  try {
    std::vector<std::string> inputNames = [self CXXAPIOrtTrainingSession].InputNames(true);
    return utils::toNSStringNSArray(inputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Returns the training model's output names (OutputNames(true)), or nil on failure.
- (nullable NSArray<NSString*>*)getTrainOutputNamesWithError:(NSError**)error {
  try {
    std::vector<std::string> outputNames = [self CXXAPIOrtTrainingSession].OutputNames(true);
    return utils::toNSStringNSArray(outputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Returns the eval model's input names (InputNames(false)), or nil on failure.
- (nullable NSArray<NSString*>*)getEvalInputNamesWithError:(NSError**)error {
  try {
    std::vector<std::string> inputNames = [self CXXAPIOrtTrainingSession].InputNames(false);
    return utils::toNSStringNSArray(inputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Returns the eval model's output names (OutputNames(false)), or nil on failure.
- (nullable NSArray<NSString*>*)getEvalOutputNamesWithError:(NSError**)error {
  try {
    std::vector<std::string> outputNames = [self CXXAPIOrtTrainingSession].OutputNames(false);
    return utils::toNSStringNSArray(outputNames);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

#pragma mark - Learning rate scheduling

/// Registers a linear learning rate scheduler via TrainingSession::RegisterLinearLRScheduler.
/// @param warmupStepCount Forwarded as the warmup step count.
/// @param totalStepCount Forwarded as the total step count.
/// @param initialLr Forwarded as the initial learning rate.
/// @return YES on success, NO (with error populated) on failure.
- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].RegisterLinearLRScheduler(warmupStepCount, totalStepCount, initialLr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Forwards to TrainingSession::SchedulerStep.
/// @return YES on success, NO (with error populated) on failure.
- (BOOL)schedulerStepWithError:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].SchedulerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Returns the current learning rate, or 0.0f (with error populated) on failure.
- (float)getLearningRateWithError:(NSError**)error {
  try {
    return [self CXXAPIOrtTrainingSession].GetLearningRate();
  }
  ORT_OBJC_API_IMPL_CATCH(error, 0.0f)
}

/// Sets the learning rate via TrainingSession::SetLearningRate.
/// @return YES on success, NO (with error populated) on failure.
- (BOOL)setLearningRate:(float)lr
                  error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].SetLearningRate(lr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

#pragma mark - Parameter buffers and export

/// Loads parameters from the given buffer value via TrainingSession::FromBuffer.
/// @return YES on success, NO (with error populated) on failure.
- (BOOL)fromBufferWithValue:(ORTValue*)buffer
                      error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].FromBuffer([buffer CXXAPIOrtValue]);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

/// Copies parameters into a new value via TrainingSession::ToBuffer.
/// @param onlyTrainable Forwarded to ToBuffer; by its name, presumably restricts
///        the buffer to trainable parameters (confirm against the C++ API docs).
/// @return A new ORTValue that takes over the returned Ort::Value, or nil on failure.
- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable
                                      error:(NSError**)error {
  try {
    Ort::Value val = [self CXXAPIOrtTrainingSession].ToBuffer(onlyTrainable);
    return [[ORTValue alloc] initWithCXXAPIOrtValue:std::move(val)
                                 externalTensorData:nil
                                              error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}

/// Exports an inference model via TrainingSession::ExportModelForInferencing.
/// @param inferenceModelPath Destination path for the exported model.
/// @param graphOutputNames Names of the graph outputs to export.
/// @return YES on success, NO (with error populated) on failure.
- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath
                             graphOutputNames:(NSArray<NSString*>*)graphOutputNames
                                        error:(NSError**)error {
  try {
    [self CXXAPIOrtTrainingSession].ExportModelForInferencing(utils::toStdString(inferenceModelPath),
                                                              utils::toStdStringVector(graphOutputNames));
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}

@end
/// Sets the random number generation seed by forwarding to the ONNX Runtime
/// C++ training API (Ort::SetSeed).
/// @param seed The seed value to use.
/// NOTE(review): this function has no error out-parameter; confirm whether
/// Ort::SetSeed can throw, since such an exception would propagate to the caller.
void ORTSetSeed(int64_t seed) {
  ::Ort::SetSeed(seed);
}
NS_ASSUME_NONNULL_END