diff --git a/cmake/onnxruntime_objectivec.cmake b/cmake/onnxruntime_objectivec.cmake index 494f28f403..4be2f51a96 100644 --- a/cmake/onnxruntime_objectivec.cmake +++ b/cmake/onnxruntime_objectivec.cmake @@ -42,11 +42,14 @@ file(GLOB onnxruntime_objc_srcs CONFIGURE_DEPENDS if(NOT onnxruntime_ENABLE_TRAINING_APIS) list(REMOVE_ITEM onnxruntime_objc_headers - "${OBJC_ROOT}/include/ort_checkpoint.h") + "${OBJC_ROOT}/include/ort_checkpoint.h" + "${OBJC_ROOT}/include/ort_training_session.h") list(REMOVE_ITEM onnxruntime_objc_srcs "${OBJC_ROOT}/ort_checkpoint_internal.h" - "${OBJC_ROOT}/ort_checkpoint.mm") + "${OBJC_ROOT}/ort_checkpoint.mm" + "${OBJC_ROOT}/ort_training_session_internal.h" + "${OBJC_ROOT}/ort_training_session.mm") endif() @@ -124,7 +127,9 @@ if(onnxruntime_BUILD_UNIT_TESTS) if(NOT onnxruntime_ENABLE_TRAINING_APIS) list(REMOVE_ITEM onnxruntime_objc_test_srcs - "${OBJC_ROOT}/test/ort_checkpoint_test.mm") + "${OBJC_ROOT}/test/ort_checkpoint_test.mm" + "${OBJC_ROOT}/test/ort_training_session_test.mm" + "${OBJC_ROOT}/test/ort_training_utils_test.mm") endif() diff --git a/objectivec/cxx_utils.h b/objectivec/cxx_utils.h new file mode 100644 index 0000000000..0b3666d3a2 --- /dev/null +++ b/objectivec/cxx_utils.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#import + +#include +#include +#include + +#import "cxx_api.h" + +NS_ASSUME_NONNULL_BEGIN +@class ORTValue; + +namespace utils { + +NSString* toNSString(const std::string& str); +NSString* _Nullable toNullableNSString(const std::optional& str); + +std::string toStdString(NSString* str); +std::optional toStdOptionalString(NSString* _Nullable str); + +std::vector toStdStringVector(NSArray* strs); +NSArray* toNSStringNSArray(const std::vector& strs); + +NSArray* _Nullable wrapUnownedCAPIOrtValues(const std::vector& values, NSError** error); + +std::vector getWrappedCAPIOrtValues(NSArray* values); + +} // namespace utils + +NS_ASSUME_NONNULL_END diff --git a/objectivec/cxx_utils.mm b/objectivec/cxx_utils.mm new file mode 100644 index 0000000000..9ebebeee0c --- /dev/null +++ b/objectivec/cxx_utils.mm @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#import "cxx_utils.h" + +#import +#import +#import + +#import "error_utils.h" + +#import "ort_value_internal.h" + +NS_ASSUME_NONNULL_BEGIN + +namespace utils { + +NSString* toNSString(const std::string& str) { + NSString* nsStr = [NSString stringWithUTF8String:str.c_str()]; + if (!nsStr) { + ORT_CXX_API_THROW("Failed to convert std::string to NSString", ORT_INVALID_ARGUMENT); + } + + return nsStr; +} + +NSString* _Nullable toNullableNSString(const std::optional& str) { + if (str.has_value()) { + return toNSString(*str); + } + return nil; +} + +std::string toStdString(NSString* str) { + return std::string([str UTF8String]); +} + +std::optional toStdOptionalString(NSString* _Nullable str) { + if (str) { + return std::optional([str UTF8String]); + } + return std::nullopt; +} + +std::vector toStdStringVector(NSArray* strs) { + std::vector result; + result.reserve(strs.count); + for (NSString* str in strs) { + result.push_back([str UTF8String]); + } + return result; +} + +NSArray* toNSStringNSArray(const std::vector& strs) { + NSMutableArray* result = [NSMutableArray arrayWithCapacity:strs.size()]; + for (const std::string& str : strs) { + NSString* nsStr = [NSString stringWithUTF8String:str.c_str()]; + if (nsStr) { + [result addObject:nsStr]; + } else { + ORT_CXX_API_THROW("Failed to convert std::string to NSString", ORT_INVALID_ARGUMENT); + } + } + return result; +} + +NSArray* _Nullable wrapUnownedCAPIOrtValues(const std::vector& values, NSError** error) { + NSMutableArray* result = [NSMutableArray arrayWithCapacity:values.size()]; + for (size_t i = 0; i < values.size(); ++i) { + ORTValue* val = [[ORTValue alloc] initWithCAPIOrtValue:values[i] externalTensorData:nil error:error]; + if (!val) { + // clean up all the C API Ortvalues which haven't been wrapped by ORTValue + for (size_t j = i; j < values.size(); ++j) { + Ort::GetApi().ReleaseValue(values[j]); + } + return nil; + } + [result addObject:val]; + } + return result; +} + +std::vector getWrappedCAPIOrtValues(NSArray* values) { + std::vector result; + for (ORTValue* val in values) { + result.push_back(static_cast([val CXXAPIOrtValue])); + } + return result; +} + +} // namespace utils + +NS_ASSUME_NONNULL_END diff --git a/objectivec/include/onnxruntime_training.h b/objectivec/include/onnxruntime_training.h new file mode 100644 index 0000000000..504447ea6f --- /dev/null +++ b/objectivec/include/onnxruntime_training.h @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// this header contains the entire ONNX Runtime training Objective-C API +// the headers below can also be imported individually + +#import "onnxruntime.h" +#import "ort_checkpoint.h" +#import "ort_training_session.h" diff --git a/objectivec/include/ort_checkpoint.h b/objectivec/include/ort_checkpoint.h index 8c55f21943..2b0144a38d 100644 --- a/objectivec/include/ort_checkpoint.h +++ b/objectivec/include/ort_checkpoint.h @@ -11,8 +11,8 @@ NS_ASSUME_NONNULL_BEGIN * An ORT checkpoint is a snapshot of the state of a model at a given point in time. * * This class holds the entire training session state that includes model parameters, - * their gradients, optimizer parameters, and user properties. The ORTTrainingSession leverages the - * ORTCheckpointState by accessing and updating the contained training state. + * their gradients, optimizer parameters, and user properties. The `ORTTrainingSession` leverages the + * `ORTCheckpoint` by accessing and updating the contained training state. * * Available since v1.16.0. * diff --git a/objectivec/include/ort_training_session.h b/objectivec/include/ort_training_session.h new file mode 100644 index 0000000000..54b7e54289 --- /dev/null +++ b/objectivec/include/ort_training_session.h @@ -0,0 +1,261 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import +#include + +NS_ASSUME_NONNULL_BEGIN + +@class ORTCheckpoint; +@class ORTEnv; +@class ORTValue; +@class ORTSessionOptions; + +/** + * Trainer class that provides methods to train, evaluate and optimize ONNX models. + * + * The training session requires four training artifacts: + * 1. Training onnx model + * 2. Evaluation onnx model (optional) + * 3. Optimizer onnx model + * 4. Checkpoint directory + * + * [onnxruntime-training python utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) + * can be used to generate above training artifacts. + * + * Available since v1.16.0. + */ +@interface ORTTrainingSession : NSObject + +- (instancetype)init NS_UNAVAILABLE; + +/** + * Creates a training session from the training artifacts that can be used to begin or resume training. + * + * The initializer instantiates the training session based on provided env and session options, which can be used to + * begin or resume training from a given checkpoint state. The checkpoint state represents the parameters of training + * session which will be moved to the device specified in the session option if needed. + * + * @param env The `ORTEnv` instance to use for the training session. + * @param sessionOptions The `ORTSessionOptions` to use for the training session. + * @param checkpoint Training states that are used as a starting point for training. + * @param trainModelPath The path to the training onnx model. + * @param evalModelPath The path to the evaluation onnx model. + * @param optimizerModelPath The path to the optimizer onnx model used to perform gradient descent. + * @param error Optional error information set if an error occurs. + * @return The instance, or nil if an error occurs. + * + */ +- (nullable instancetype)initWithEnv:(ORTEnv*)env + sessionOptions:(ORTSessionOptions*)sessionOptions + checkpoint:(ORTCheckpoint*)checkpoint + trainModelPath:(NSString*)trainModelPath + evalModelPath:(nullable NSString*)evalModelPath + optimizerModelPath:(nullable NSString*)optimizerModelPath + error:(NSError**)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs a training step, which is equivalent to a forward and backward propagation in a single step. + * + * The training step computes the outputs of the training model and the gradients of the trainable parameters + * for the given input values. The train step is performed based on the training model that was provided to the training session. + * It is equivalent to running forward and backward propagation in a single step. The computed gradients are stored inside + * the training session state so they can be later consumed by `optimizerStep`. The gradients can be lazily reset by + * calling `lazyResetGrad` method. + * + * @param inputs The input values to the training model. + * @param error Optional error information set if an error occurs. + * @return The output values of the training model. + */ +- (nullable NSArray*)trainStepWithInputValues:(NSArray*)inputs + error:(NSError**)error; + +/** + * Performs a evaluation step that computes the outputs of the evaluation model for the given inputs. + * The eval step is performed based on the evaluation model that was provided to the training session. + * + * @param inputs The input values to the eval model. + * @param error Optional error information set if an error occurs. + * @return The output values of the eval model. + * + */ +- (nullable NSArray*)evalStepWithInputValues:(NSArray*)inputs + error:(NSError**)error; + +/** + * Reset the gradients of all trainable parameters to zero lazily. + * + * Calling this method sets the internal state of the training session such that the gradients of the trainable parameters + * in the ORTCheckpoint will be scheduled to be reset just before the new gradients are computed on the next + * invocation of the `trainStep` method. + * + * @param error Optional error information set if an error occurs. + * @return YES if the gradients are set to reset successfully, NO otherwise. + */ +- (BOOL)lazyResetGradWithError:(NSError**)error; + +/** + * Performs the weight updates for the trainable parameters using the optimizer model. The optimizer step is performed + * based on the optimizer model that was provided to the training session. The updated parameters are stored inside the + * training state so that they can be used by the next `trainStep` method call. + * + * @param error Optional error information set if an error occurs. + * @return YES if the optimizer step was performed successfully, NO otherwise. + */ +- (BOOL)optimizerStepWithError:(NSError**)error; + +/** + * Returns the names of the user inputs for the training model that can be associated with + * the `ORTValue` provided to the `trainStep`. + * + * @param error Optional error information set if an error occurs. + * @return The names of the user inputs for the training model. + */ +- (nullable NSArray*)getTrainInputNamesWithError:(NSError**)error; + +/** + * Returns the names of the user inputs for the evaluation model that can be associated with + * the `ORTValue` provided to the `evalStep`. + * + * @param error Optional error information set if an error occurs. + * @return The names of the user inputs for the evaluation model. + */ +- (nullable NSArray*)getEvalInputNamesWithError:(NSError**)error; + +/** + * Returns the names of the user outputs for the training model that can be associated with + * the `ORTValue` returned by the `trainStep`. + * + * @param error Optional error information set if an error occurs. + * @return The names of the user outputs for the training model. + */ +- (nullable NSArray*)getTrainOutputNamesWithError:(NSError**)error; + +/** + * Returns the names of the user outputs for the evaluation model that can be associated with + * the `ORTValue` returned by the `evalStep`. + * + * @param error Optional error information set if an error occurs. + * @return The names of the user outputs for the evaluation model. + */ +- (nullable NSArray*)getEvalOutputNamesWithError:(NSError**)error; + +/** + * Registers a linear learning rate scheduler for the training session. + * + * The scheduler gradually decreases the learning rate from the initial value to zero over the course of the training. + * The decrease is performed by multiplying the current learning rate by a linearly updated factor. + * Before the decrease, the learning rate is gradually increased from zero to the initial value during a warmup phase. + * + * @param warmupStepCount The number of steps to perform the linear warmup. + * @param totalStepCount The total number of steps to perform the linear decay. + * @param initialLr The initial learning rate. + * @param error Optional error information set if an error occurs. + * @return YES if the scheduler was registered successfully, NO otherwise. + */ +- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount + totalStepCount:(int64_t)totalStepCount + initialLr:(float)initialLr + error:(NSError**)error; + +/** + * Update the learning rate based on the registered learning rate scheduler. + * + * Performs a scheduler step that updates the learning rate that is being used by the training session. + * This function should typically be called before invoking the optimizer step for each round, or as necessary + * to update the learning rate being used by the training session. + * + * @note A valid predefined learning rate scheduler must be first registered to invoke this method. + * + * @param error Optional error information set if an error occurs. + * @return YES if the scheduler step was performed successfully, NO otherwise. + */ +- (BOOL)schedulerStepWithError:(NSError**)error; + +/** + * Returns the current learning rate being used by the training session. + * + * @param error Optional error information set if an error occurs. + * @return The current learning rate or 0.0f if an error occurs. + */ +- (float)getLearningRateWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); + +/** + * Sets the learning rate being used by the training session. + * + * The current learning rate is maintained by the training session and can be overwritten by invoking this method + * with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered. + * It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant + * learning rate to be used throughout the training session. + * + * @note It does not set the initial learning rate that may be needed by the predefined learning rate schedulers. + * To set the initial learning rate for learning rate schedulers, use the `registerLinearLRScheduler` method. + * + * @param lr The learning rate to be used by the training session. + * @param error Optional error information set if an error occurs. + * @return YES if the learning rate was set successfully, NO otherwise. + */ +- (BOOL)setLearningRate:(float)lr + error:(NSError**)error; + +/** + * Loads the training session model parameters from a contiguous buffer. + * + * @param buffer Contiguous buffer to load the parameters from. + * @param error Optional error information set if an error occurs. + * @return YES if the parameters were loaded successfully, NO otherwise. + */ +- (BOOL)fromBufferWithValue:(ORTValue*)buffer + error:(NSError**)error; + +/** + * Returns a contiguous buffer that holds a copy of all training state parameters. + * + * @param onlyTrainable If YES, returns a buffer that holds only the trainable parameters, otherwise returns a buffer + * that holds all the parameters. + * @param error Optional error information set if an error occurs. + * @return A contiguous buffer that holds a copy of all training state parameters. + */ +- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable + error:(NSError**)error; + +/** + * Exports the training session model that can be used for inference. + * + * If the training session was provided with an eval model, the training session can generate an inference model if it + * knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the + * inference model's outputs align with the provided outputs. The exported model is saved at the path provided and + * can be used for inferencing with `ORTSession`. + * + * @note The method reloads the eval model from the path provided to the initializer and expects this path to be valid. + * + * @param inferenceModelPath The path to the serialized the inference model. + * @param graphOutputNames The names of the outputs that are needed in the inference model. + * @param error Optional error information set if an error occurs. + * @return YES if the inference model was exported successfully, NO otherwise. + */ +- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath + graphOutputNames:(NSArray*)graphOutputNames + error:(NSError**)error; +@end + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * This function sets the seed for generating random numbers. + * Use this function to generate reproducible results. It should be noted that completely reproducible results are not guaranteed. + * + * @param seed Manually set seed to use for random number generation. + */ +void ORTSetSeed(int64_t seed); + +#ifdef __cplusplus +} +#endif + +NS_ASSUME_NONNULL_END + +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_training_session.mm b/objectivec/ort_training_session.mm new file mode 100644 index 0000000000..638492d5ff --- /dev/null +++ b/objectivec/ort_training_session.mm @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import "ort_training_session_internal.h" + +#import +#import +#import + +#import "cxx_api.h" +#import "cxx_utils.h" +#import "error_utils.h" +#import "ort_checkpoint_internal.h" +#import "ort_session_internal.h" +#import "ort_enums_internal.h" +#import "ort_env_internal.h" +#import "ort_value_internal.h" + +NS_ASSUME_NONNULL_BEGIN + +@implementation ORTTrainingSession { + std::optional _session; + ORTCheckpoint* _checkpoint; +} + +- (Ort::TrainingSession&)CXXAPIOrtTrainingSession { + return *_session; +} + +- (nullable instancetype)initWithEnv:(ORTEnv*)env + sessionOptions:(ORTSessionOptions*)sessionOptions + checkpoint:(ORTCheckpoint*)checkpoint + trainModelPath:(NSString*)trainModelPath + evalModelPath:(nullable NSString*)evalModelPath + optimizerModelPath:(nullable NSString*)optimizerModelPath + error:(NSError**)error { + if ((self = [super init]) == nil) { + return nil; + } + + try { + std::optional evalPath = utils::toStdOptionalString(evalModelPath); + std::optional optimizerPath = utils::toStdOptionalString(optimizerModelPath); + + _checkpoint = checkpoint; + _session = Ort::TrainingSession{ + [env CXXAPIOrtEnv], + [sessionOptions CXXAPIOrtSessionOptions], + [checkpoint CXXAPIOrtCheckpoint], + trainModelPath.UTF8String, + evalPath, + optimizerPath}; + return self; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (nullable NSArray*)trainStepWithInputValues:(NSArray*)inputs + error:(NSError**)error { + try { + std::vector inputValues = utils::getWrappedCAPIOrtValues(inputs); + + size_t outputCount; + Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(*_session, &outputCount)); + std::vector outputValues(outputCount, nullptr); + + Ort::RunOptions runOptions; + Ort::ThrowOnError(Ort::GetTrainingApi().TrainStep( + *_session, + runOptions, + inputValues.size(), + inputValues.data(), + outputValues.size(), + outputValues.data())); + + return utils::wrapUnownedCAPIOrtValues(outputValues, error); + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} +- (nullable NSArray*)evalStepWithInputValues:(NSArray*)inputs + error:(NSError**)error { + try { + // create vector of OrtValue from NSArray with same size as inputValues + std::vector inputValues = utils::getWrappedCAPIOrtValues(inputs); + + size_t outputCount; + Ort::ThrowOnError(Ort::GetTrainingApi().TrainingSessionGetEvalModelOutputCount(*_session, &outputCount)); + std::vector outputValues(outputCount, nullptr); + + Ort::RunOptions runOptions; + Ort::ThrowOnError(Ort::GetTrainingApi().EvalStep( + *_session, + runOptions, + inputValues.size(), + inputValues.data(), + outputValues.size(), + outputValues.data())); + + return utils::wrapUnownedCAPIOrtValues(outputValues, error); + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (BOOL)lazyResetGradWithError:(NSError**)error { + try { + [self CXXAPIOrtTrainingSession].LazyResetGrad(); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (BOOL)optimizerStepWithError:(NSError**)error { + try { + [self CXXAPIOrtTrainingSession].OptimizerStep(); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (nullable NSArray*)getTrainInputNamesWithError:(NSError**)error { + try { + std::vector inputNames = [self CXXAPIOrtTrainingSession].InputNames(true); + return utils::toNSStringNSArray(inputNames); + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (nullable NSArray*)getTrainOutputNamesWithError:(NSError**)error { + try { + std::vector outputNames = [self CXXAPIOrtTrainingSession].OutputNames(true); + return utils::toNSStringNSArray(outputNames); + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (nullable NSArray*)getEvalInputNamesWithError:(NSError**)error { + try { + std::vector inputNames = [self CXXAPIOrtTrainingSession].InputNames(false); + return utils::toNSStringNSArray(inputNames); + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (nullable NSArray*)getEvalOutputNamesWithError:(NSError**)error { + try { + std::vector outputNames = [self CXXAPIOrtTrainingSession].OutputNames(false); + return utils::toNSStringNSArray(outputNames); + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount + totalStepCount:(int64_t)totalStepCount + initialLr:(float)initialLr + error:(NSError**)error { + try { + [self CXXAPIOrtTrainingSession].RegisterLinearLRScheduler(warmupStepCount, totalStepCount, initialLr); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (BOOL)schedulerStepWithError:(NSError**)error { + try { + [self CXXAPIOrtTrainingSession].SchedulerStep(); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (float)getLearningRateWithError:(NSError**)error { + try { + return [self CXXAPIOrtTrainingSession].GetLearningRate(); + } + ORT_OBJC_API_IMPL_CATCH(error, 0.0f); +} + +- (BOOL)setLearningRate:(float)lr + error:(NSError**)error { + try { + [self CXXAPIOrtTrainingSession].SetLearningRate(lr); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (BOOL)fromBufferWithValue:(ORTValue*)buffer + error:(NSError**)error { + try { + [self CXXAPIOrtTrainingSession].FromBuffer([buffer CXXAPIOrtValue]); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable + error:(NSError**)error { + try { + Ort::Value val = [self CXXAPIOrtTrainingSession].ToBuffer(onlyTrainable); + return [[ORTValue alloc] initWithCAPIOrtValue:val.release() + externalTensorData:nil + error:error]; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath + graphOutputNames:(NSArray*)graphOutputNames + error:(NSError**)error { + try { + [self CXXAPIOrtTrainingSession].ExportModelForInferencing(utils::toStdString(inferenceModelPath), + utils::toStdStringVector(graphOutputNames)); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +@end + +void ORTSetSeed(int64_t seed) { + Ort::SetSeed(seed); +} + +NS_ASSUME_NONNULL_END + +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_training_session_internal.h b/objectivec/ort_training_session_internal.h new file mode 100644 index 0000000000..402c84eb5b --- /dev/null +++ b/objectivec/ort_training_session_internal.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import "ort_training_session.h" + +#import "cxx_api.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface ORTTrainingSession () + +- (Ort::TrainingSession&)CXXAPIOrtTrainingSession; + +@end + +NS_ASSUME_NONNULL_END + +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/test/assertion_utils.h b/objectivec/test/assertion_utils.h index f2b73e6d53..2b72435b95 100644 --- a/objectivec/test/assertion_utils.h +++ b/objectivec/test/assertion_utils.h @@ -29,4 +29,18 @@ NS_ASSUME_NONNULL_BEGIN XCTAssertNotNil(error); \ } while (0) +#define ORTAssertEqualFloatAndNoError(expected, result, error) \ + do { \ + XCTAssertEqualWithAccuracy(expected, result, 1e-3f, @"Expected %f but got %f. Error:%@", expected, result, error); \ + XCTAssertNil(error); \ + } while (0) + +#define ORTAssertEqualFloatArrays(expected, result) \ + do { \ + XCTAssertEqual(expected.count, result.count); \ + for (size_t i = 0; i < expected.count; ++i) { \ + XCTAssertEqualWithAccuracy([expected[i] floatValue], [result[i] floatValue], 1e-3f); \ + } \ + } while (0) + NS_ASSUME_NONNULL_END diff --git a/objectivec/test/ort_checkpoint_test.mm b/objectivec/test/ort_checkpoint_test.mm index 01788a5bc0..df97dcf01d 100644 --- a/objectivec/test/ort_checkpoint_test.mm +++ b/objectivec/test/ort_checkpoint_test.mm @@ -5,7 +5,11 @@ #import #import "ort_checkpoint.h" +#import "ort_training_session.h" #import "ort_env.h" +#import "ort_session.h" + +#import "test/test_utils.h" #import "test/assertion_utils.h" NS_ASSUME_NONNULL_BEGIN @@ -27,22 +31,41 @@ NS_ASSUME_NONNULL_BEGIN ORTAssertNullableResultSuccessful(_ortEnv, err); } -- (NSString*)getCheckpointPath { ++ (NSString*)getCheckpointPath { NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]]; NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"]; return path; } -- (void)testLoadCheckpoint { ++ (NSString*)getTrainingModelPath { + NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]]; + NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"training_model.onnx"]; + return path; +} + +- (void)testSaveCheckpoint { NSError* error = nil; - ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error]; + ORTAssertNullableResultSuccessful(checkpoint, error); + + // save checkpoint + NSString* path = [test_utils::createTemporaryDirectory(self) stringByAppendingPathComponent:@"save_checkpoint.ckpt"]; + XCTAssertNotNil(path); + BOOL result = [checkpoint saveCheckpointToPath:path withOptimizerState:NO error:&error]; + + ORTAssertBoolResultSuccessful(result, error); +} + +- (void)testInitCheckpoint { + NSError* error = nil; + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error]; ORTAssertNullableResultSuccessful(checkpoint, error); } - (void)testIntProperty { NSError* error = nil; // Load checkpoint - ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error]; ORTAssertNullableResultSuccessful(checkpoint, error); // Add property @@ -57,7 +80,7 @@ NS_ASSUME_NONNULL_BEGIN - (void)testFloatProperty { NSError* error = nil; // Load checkpoint - ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error]; ORTAssertNullableResultSuccessful(checkpoint, error); // Add property @@ -72,7 +95,7 @@ NS_ASSUME_NONNULL_BEGIN - (void)testStringProperty { NSError* error = nil; // Load checkpoint - ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error]; ORTAssertNullableResultSuccessful(checkpoint, error); // Add property diff --git a/objectivec/test/ort_training_session_test.mm b/objectivec/test/ort_training_session_test.mm new file mode 100644 index 0000000000..30ef51f3a0 --- /dev/null +++ b/objectivec/test/ort_training_session_test.mm @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import + +#import "ort_checkpoint.h" +#import "ort_training_session.h" +#import "ort_env.h" +#import "ort_session.h" +#import "ort_value.h" + +#import "test/test_utils.h" +#import "test/assertion_utils.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface ORTTrainingSessionTest : XCTestCase +@property(readonly, nullable) ORTEnv* ortEnv; +@property(readonly, nullable) ORTCheckpoint* checkpoint; +@property(readonly, nullable) ORTTrainingSession* session; +@end + +@implementation ORTTrainingSessionTest + +- (void)setUp { + [super setUp]; + + self.continueAfterFailure = NO; + + NSError* err = nil; + _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning + error:&err]; + ORTAssertNullableResultSuccessful(_ortEnv, err); + _checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTTrainingSessionTest + getFilePathFromName:@"checkpoint.ckpt"] + error:&err]; + ORTAssertNullableResultSuccessful(_checkpoint, err); + _session = [self makeTrainingSessionWithCheckpoint:_checkpoint]; +} + ++ (NSString*)getFilePathFromName:(NSString*)name { + NSBundle* bundle = [NSBundle bundleForClass:[ORTTrainingSessionTest class]]; + NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:name]; + return path; +} + ++ (NSMutableData*)loadTensorDataFromFile:(NSString*)filePath skipHeader:(BOOL)skipHeader { + NSError* error = nil; + NSString* fileContents = [NSString stringWithContentsOfFile:filePath + encoding:NSUTF8StringEncoding + error:&error]; + ORTAssertNullableResultSuccessful(fileContents, error); + + NSArray* lines = [fileContents componentsSeparatedByCharactersInSet:[NSCharacterSet newlineCharacterSet]]; + + if (skipHeader) { + lines = [lines subarrayWithRange:NSMakeRange(1, lines.count - 1)]; + } + + NSArray* dataArray = [lines[0] componentsSeparatedByCharactersInSet: + [NSCharacterSet characterSetWithCharactersInString:@",[] "]]; + NSMutableData* tensorData = [NSMutableData data]; + + for (NSString* str in dataArray) { + if (str.length > 0) { + float value = [str floatValue]; + [tensorData appendBytes:&value length:sizeof(float)]; + } + } + + return tensorData; +} + +- (ORTTrainingSession*)makeTrainingSessionWithCheckpoint:(ORTCheckpoint*)checkpoint { + NSError* error = nil; + ORTSessionOptions* sessionOptions = [[ORTSessionOptions alloc] initWithError:&error]; + ORTAssertNullableResultSuccessful(sessionOptions, error); + + ORTTrainingSession* session = [[ORTTrainingSession alloc] + initWithEnv:self.ortEnv + sessionOptions:sessionOptions + checkpoint:checkpoint + trainModelPath:[ORTTrainingSessionTest getFilePathFromName:@"training_model.onnx"] + evalModelPath:[ORTTrainingSessionTest getFilePathFromName:@"eval_model.onnx"] + optimizerModelPath:[ORTTrainingSessionTest getFilePathFromName:@"adamw.onnx"] + error:&error]; + + ORTAssertNullableResultSuccessful(session, error); + return session; +} + +- (void)testInitTrainingSession { + NSError* error = nil; + + // check that inputNames contains input-0 + NSArray* inputNames = [self.session getTrainInputNamesWithError:&error]; + ORTAssertNullableResultSuccessful(inputNames, error); + + XCTAssertTrue(inputNames.count > 0); + XCTAssertTrue([inputNames containsObject:@"input-0"]); + + // check that outNames contains onnx::loss::21273 + NSArray* outputNames = [self.session getTrainOutputNamesWithError:&error]; + ORTAssertNullableResultSuccessful(outputNames, error); + + XCTAssertTrue(outputNames.count > 0); + XCTAssertTrue([outputNames containsObject:@"onnx::loss::21273"]); +} + +- (void)testInitTrainingSessionWithEval { + NSError* error = nil; + + // check that inputNames contains input-0 + NSArray* inputNames = [self.session getEvalInputNamesWithError:&error]; + ORTAssertNullableResultSuccessful(inputNames, error); + + XCTAssertTrue(inputNames.count > 0); + XCTAssertTrue([inputNames containsObject:@"input-0"]); + + // check that outNames contains onnx::loss::21273 + NSArray* outputNames = [self.session getEvalOutputNamesWithError:&error]; + ORTAssertNullableResultSuccessful(outputNames, error); + + XCTAssertTrue(outputNames.count > 0); + XCTAssertTrue([outputNames containsObject:@"onnx::loss::21273"]); +} + +- (void)runTrainStep { + // load input and expected output + NSError* error = nil; + NSMutableData* expectedOutput = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest + getFilePathFromName:@"loss_1.out"] + skipHeader:YES]; + + NSMutableData* input = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest + getFilePathFromName:@"input-0.in"] + skipHeader:YES]; + + int32_t labels[] = {1, 1}; + + // create ORTValue array for input and labels + NSMutableArray* inputValues = [NSMutableArray array]; + + ORTValue* inputTensor = [[ORTValue alloc] initWithTensorData:input + elementType:ORTTensorElementDataTypeFloat + shape:@[ @2, @784 ] + error:&error]; + ORTAssertNullableResultSuccessful(inputTensor, error); + [inputValues addObject:inputTensor]; + + ORTValue* labelTensor = [[ORTValue alloc] initWithTensorData:[NSMutableData dataWithBytes:labels + length:sizeof(labels)] + elementType:ORTTensorElementDataTypeInt32 + shape:@[ @2 ] + error:&error]; + + ORTAssertNullableResultSuccessful(labelTensor, error); + [inputValues addObject:labelTensor]; + + NSArray* outputs = [self.session trainStepWithInputValues:inputValues error:&error]; + ORTAssertNullableResultSuccessful(outputs, error); + XCTAssertTrue(outputs.count > 0); + + BOOL result = [self.session lazyResetGradWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + + outputs = [self.session trainStepWithInputValues:inputValues error:&error]; + ORTAssertNullableResultSuccessful(outputs, error); + XCTAssertTrue(outputs.count > 0); + + ORTValue* outputValue = outputs[0]; + ORTValueTypeInfo* typeInfo = [outputValue typeInfoWithError:&error]; + ORTAssertNullableResultSuccessful(typeInfo, error); + XCTAssertEqual(typeInfo.type, ORTValueTypeTensor); + XCTAssertNotNil(typeInfo.tensorTypeAndShapeInfo); + + ORTTensorTypeAndShapeInfo* tensorInfo = [outputValue tensorTypeAndShapeInfoWithError:&error]; + ORTAssertNullableResultSuccessful(tensorInfo, error); + XCTAssertEqual(tensorInfo.elementType, ORTTensorElementDataTypeFloat); + + NSMutableData* tensorData = [outputValue tensorDataWithError:&error]; + ORTAssertNullableResultSuccessful(tensorData, error); + ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(tensorData), + test_utils::getFloatArrayFromData(expectedOutput)); +} + +- (void)testTrainStepOutput { + [self runTrainStep]; +} + +- (void)testOptimizerStep { + // load input and expected output + NSError* error = nil; + NSMutableData* expectedOutput1 = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest + getFilePathFromName:@"loss_1.out"] + skipHeader:YES]; + + NSMutableData* expectedOutput2 = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest + getFilePathFromName:@"loss_2.out"] + skipHeader:YES]; + + NSMutableData* input = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest + getFilePathFromName:@"input-0.in"] + skipHeader:YES]; + + int32_t labels[] = {1, 1}; + + // create ORTValue array for input and labels + NSMutableArray* inputValues = [NSMutableArray array]; + + ORTValue* inputTensor = [[ORTValue alloc] initWithTensorData:input + elementType:ORTTensorElementDataTypeFloat + shape:@[ @2, @784 ] + error:&error]; + ORTAssertNullableResultSuccessful(inputTensor, error); + [inputValues addObject:inputTensor]; + + ORTValue* labelTensor = [[ORTValue alloc] initWithTensorData:[NSMutableData dataWithBytes:labels + length:sizeof(labels)] + elementType:ORTTensorElementDataTypeInt32 + shape:@[ @2 ] + error:&error]; + ORTAssertNullableResultSuccessful(labelTensor, error); + [inputValues addObject:labelTensor]; + + // run train step, optimizer steps and check loss + NSArray* outputs = [self.session trainStepWithInputValues:inputValues error:&error]; + ORTAssertNullableResultSuccessful(outputs, error); + + NSMutableData* loss = [outputs[0] tensorDataWithError:&error]; + ORTAssertNullableResultSuccessful(loss, error); + ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss), + test_utils::getFloatArrayFromData(expectedOutput1)); + + BOOL result = [self.session lazyResetGradWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + + outputs = [self.session trainStepWithInputValues:inputValues error:&error]; + ORTAssertNullableResultSuccessful(outputs, error); + + loss = [outputs[0] tensorDataWithError:&error]; + ORTAssertNullableResultSuccessful(loss, error); + ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss), + test_utils::getFloatArrayFromData(expectedOutput1)); + + result = [self.session optimizerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + + outputs = [self.session trainStepWithInputValues:inputValues error:&error]; + ORTAssertNullableResultSuccessful(outputs, error); + loss = [outputs[0] tensorDataWithError:&error]; + ORTAssertNullableResultSuccessful(loss, error); + ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss), + test_utils::getFloatArrayFromData(expectedOutput2)); +} + +- (void)testSetLearningRate { + NSError* error = nil; + + float learningRate = 0.1f; + BOOL result = [self.session setLearningRate:learningRate error:&error]; + ORTAssertBoolResultSuccessful(result, error); + + float actualLearningRate = [self.session getLearningRateWithError:&error]; + ORTAssertEqualFloatAndNoError(learningRate, actualLearningRate, error); +} + +- (void)testLinearLRScheduler { + NSError* error = nil; + + float learningRate = 0.1f; + BOOL result = [self.session registerLinearLRSchedulerWithWarmupStepCount:2 + totalStepCount:4 + initialLr:learningRate + error:&error]; + + ORTAssertBoolResultSuccessful(result, error); + + [self runTrainStep]; + + result = [self.session optimizerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + result = [self.session schedulerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + ORTAssertEqualFloatAndNoError(0.05f, [self.session getLearningRateWithError:&error], error); + + result = [self.session optimizerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + result = [self.session schedulerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + ORTAssertEqualFloatAndNoError(0.1f, [self.session getLearningRateWithError:&error], error); + + result = [self.session optimizerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + result = [self.session schedulerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + ORTAssertEqualFloatAndNoError(0.05f, [self.session getLearningRateWithError:&error], error); + + result = [self.session optimizerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + result = [self.session schedulerStepWithError:&error]; + ORTAssertBoolResultSuccessful(result, error); + ORTAssertEqualFloatAndNoError(0.0f, [self.session getLearningRateWithError:&error], error); +} + +- (void)testExportModelForInference { + NSError* error = nil; + + NSString* inferenceModelPath = [test_utils::createTemporaryDirectory(self) + stringByAppendingPathComponent:@"inference_model.onnx"]; + XCTAssertNotNil(inferenceModelPath); + + NSArray* graphOutputNames = [NSArray arrayWithObjects:@"output-0", nil]; + + BOOL result = [self.session exportModelForInferenceWithOutputPath:inferenceModelPath + graphOutputNames:graphOutputNames + error:&error]; + + ORTAssertBoolResultSuccessful(result, error); + XCTAssertTrue([[NSFileManager defaultManager] fileExistsAtPath:inferenceModelPath]); + + [self addTeardownBlock:^{ + NSError* error = nil; + [[NSFileManager defaultManager] removeItemAtPath:inferenceModelPath error:&error]; + }]; +} + +- (void)testToBuffer { + NSError* error = nil; + ORTValue* buffer = [self.session toBufferWithTrainable:YES error:&error]; + ORTAssertNullableResultSuccessful(buffer, error); + + ORTValueTypeInfo* typeInfo = [buffer typeInfoWithError:&error]; + ORTAssertNullableResultSuccessful(typeInfo, error); + XCTAssertEqual(typeInfo.type, ORTValueTypeTensor); + XCTAssertNotNil(typeInfo.tensorTypeAndShapeInfo); +} + +- (void)testFromBuffer { + NSError* error = nil; + + ORTValue* buffer = [self.session toBufferWithTrainable:YES error:&error]; + ORTAssertNullableResultSuccessful(buffer, error); + + BOOL result = [self.session fromBufferWithValue:buffer error:&error]; + ORTAssertBoolResultSuccessful(result, error); +} + +- (void)tearDown { + _session = nil; + _checkpoint = nil; + _ortEnv = nil; + + [super tearDown]; +} + +@end + +NS_ASSUME_NONNULL_END + +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/test/ort_training_utils_test.mm b/objectivec/test/ort_training_utils_test.mm new file mode 100644 index 0000000000..77636bb237 --- /dev/null +++ b/objectivec/test/ort_training_utils_test.mm @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import +#import "ort_training_session.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface ORTTrainingUtilsTest : XCTestCase +@end + +@implementation ORTTrainingUtilsTest + +- (void)setUp { + [super setUp]; + + self.continueAfterFailure = NO; +} + +- (void)testSetSeed { + ORTSetSeed(2718); +} + +@end + +NS_ASSUME_NONNULL_END + +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/test/test_utils.h b/objectivec/test/test_utils.h new file mode 100644 index 0000000000..8a5e6e4821 --- /dev/null +++ b/objectivec/test/test_utils.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#import +#import + +#import "ort_session.h" +#import "ort_env.h" +#import "ort_value.h" + +NS_ASSUME_NONNULL_BEGIN + +namespace test_utils { + +NSString* _Nullable createTemporaryDirectory(XCTestCase* testCase); + +NSArray* getFloatArrayFromData(NSData* data); + +} // namespace test_utils + +NS_ASSUME_NONNULL_END diff --git a/objectivec/test/test_utils.mm b/objectivec/test/test_utils.mm new file mode 100644 index 0000000000..ef8e8c563b --- /dev/null +++ b/objectivec/test/test_utils.mm @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#import "test_utils.h" + +NS_ASSUME_NONNULL_BEGIN + +namespace test_utils { + +NSString* createTemporaryDirectory(XCTestCase* testCase) { + NSString* temporaryDirectory = NSTemporaryDirectory(); + NSString* directoryPath = [temporaryDirectory stringByAppendingPathComponent:@"ort-objective-c-test"]; + + NSError* error = nil; + [[NSFileManager defaultManager] createDirectoryAtPath:directoryPath + withIntermediateDirectories:YES + attributes:nil + error:&error]; + + XCTAssertNil(error, @"Error creating temporary directory: %@", error.localizedDescription); + + // add teardown block to delete the temporary directory + [testCase addTeardownBlock:^{ + NSError* error = nil; + [[NSFileManager defaultManager] removeItemAtPath:directoryPath error:&error]; + XCTAssertNil(error, @"Error removing temporary directory: %@", error.localizedDescription); + }]; + + return directoryPath; +} + +NSArray* getFloatArrayFromData(NSData* data) { + NSMutableArray* array = [NSMutableArray array]; + float value; + for (size_t i = 0; i < data.length / sizeof(float); ++i) { + [data getBytes:&value range:NSMakeRange(i * sizeof(float), sizeof(float))]; + [array addObject:[NSNumber numberWithFloat:value]]; + } + return array; +} + +} // namespace test_utils + +NS_ASSUME_NONNULL_END diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 8653244844..5bfdfcc74e 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -194,7 +194,6 @@ class TrainingSession : public detail::Base { * \param[in] input_values The user inputs to the training model. * \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model. * - * \snippet{doc} snippets.dox OrtStatus Return Value * */ std::vector TrainStep(const std::vector& input_values);