Objective C Training API: TrainingSession (#16374)

### Description
- Implement Objective-C binding for `ORTTrainingSession`
- Add `ORTUtils` utility class to handle conversion between C++ and
Objective-C types
- Add test case for saving checkpoint
- Add unit test cases for `ORTTrainingSession`

### Motivation and Context
This PR is part of implementing the Objective-C bindings for the training API.
It implements the Objective-C binding for the training session. The Objective-C
API closely resembles the C++ API.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
Vrajang Parikh 2023-06-28 09:13:56 -07:00 committed by GitHub
parent 6dd4e4801a
commit 960e320dff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1150 additions and 12 deletions

View file

@ -42,11 +42,14 @@ file(GLOB onnxruntime_objc_srcs CONFIGURE_DEPENDS
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
list(REMOVE_ITEM onnxruntime_objc_headers
"${OBJC_ROOT}/include/ort_checkpoint.h")
"${OBJC_ROOT}/include/ort_checkpoint.h"
"${OBJC_ROOT}/include/ort_training_session.h")
list(REMOVE_ITEM onnxruntime_objc_srcs
"${OBJC_ROOT}/ort_checkpoint_internal.h"
"${OBJC_ROOT}/ort_checkpoint.mm")
"${OBJC_ROOT}/ort_checkpoint.mm"
"${OBJC_ROOT}/ort_training_session_internal.h"
"${OBJC_ROOT}/ort_training_session.mm")
endif()
@ -124,7 +127,9 @@ if(onnxruntime_BUILD_UNIT_TESTS)
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
list(REMOVE_ITEM onnxruntime_objc_test_srcs
"${OBJC_ROOT}/test/ort_checkpoint_test.mm")
"${OBJC_ROOT}/test/ort_checkpoint_test.mm"
"${OBJC_ROOT}/test/ort_training_session_test.mm"
"${OBJC_ROOT}/test/ort_training_utils_test.mm")
endif()

32
objectivec/cxx_utils.h Normal file
View file

@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import <Foundation/Foundation.h>
#include <optional>
#include <string>
#include <variant>
#import "cxx_api.h"
NS_ASSUME_NONNULL_BEGIN
@class ORTValue;
// Conversion helpers between C++ standard library / ORT C API types and their
// Foundation / Objective-C wrapper counterparts.
namespace utils {
// Builds an NSString from a UTF-8 std::string. A failed conversion is reported through
// ORT_CXX_API_THROW with ORT_INVALID_ARGUMENT.
NSString* toNSString(const std::string& str);
// Like toNSString, but maps an empty optional to nil.
NSString* _Nullable toNullableNSString(const std::optional<std::string>& str);
// Builds a std::string from the UTF-8 representation of `str`.
std::string toStdString(NSString* str);
// Like toStdString, but maps nil to std::nullopt.
std::optional<std::string> toStdOptionalString(NSString* _Nullable str);
// Converts every element of `strs` to a std::string, preserving order.
std::vector<std::string> toStdStringVector(NSArray<NSString*>* strs);
// Converts every element of `strs` to an NSString, preserving order. A failed conversion is
// reported through ORT_CXX_API_THROW with ORT_INVALID_ARGUMENT.
NSArray<NSString*>* toNSStringNSArray(const std::vector<std::string>& strs);
// Wraps each C API OrtValue in a new ORTValue (without external tensor data), preserving order.
// If creating a wrapper fails, the value that failed and every value after it are released with
// the C API's ReleaseValue, and nil is returned; `error` is the one passed to the ORTValue initializer.
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& values, NSError** error);
// Collects the C API OrtValue pointer held by each ORTValue in `values`, preserving order.
std::vector<const OrtValue*> getWrappedCAPIOrtValues(NSArray<ORTValue*>* values);
} // namespace utils
NS_ASSUME_NONNULL_END

93
objectivec/cxx_utils.mm Normal file
View file

@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import "cxx_utils.h"
#import <vector>
#import <optional>
#import <string>
#import "error_utils.h"
#import "ort_value_internal.h"
NS_ASSUME_NONNULL_BEGIN
namespace utils {
// Converts a UTF-8 encoded std::string into an NSString.
// Fails with ORT_INVALID_ARGUMENT when Foundation cannot build a string from the bytes.
NSString* toNSString(const std::string& str) {
  if (NSString* converted = [NSString stringWithUTF8String:str.c_str()]) {
    return converted;
  }
  ORT_CXX_API_THROW("Failed to convert std::string to NSString", ORT_INVALID_ARGUMENT);
}
// Converts an optional UTF-8 std::string into an NSString; an empty optional becomes nil.
NSString* _Nullable toNullableNSString(const std::optional<std::string>& str) {
  if (!str) {
    return nil;
  }
  return toNSString(str.value());
}
// Converts an NSString into a UTF-8 encoded std::string.
// -UTF8String yields NULL when the string has no UTF-8 representation (and when `str` is nil),
// and constructing a std::string from a null pointer is undefined behaviour, so that case is
// reported with ORT_INVALID_ARGUMENT, mirroring toNSString.
std::string toStdString(NSString* str) {
  const char* utf8 = [str UTF8String];
  if (utf8 == nullptr) {
    ORT_CXX_API_THROW("Failed to convert NSString to std::string", ORT_INVALID_ARGUMENT);
  }
  return std::string(utf8);
}
// Converts a nullable NSString into an optional UTF-8 std::string; nil becomes std::nullopt.
// A non-nil string that has no UTF-8 representation (-UTF8String returns NULL) is reported with
// ORT_INVALID_ARGUMENT instead of constructing a std::string from a null pointer.
std::optional<std::string> toStdOptionalString(NSString* _Nullable str) {
  if (!str) {
    return std::nullopt;
  }
  const char* utf8 = [str UTF8String];
  if (utf8 == nullptr) {
    ORT_CXX_API_THROW("Failed to convert NSString to std::string", ORT_INVALID_ARGUMENT);
  }
  return std::optional<std::string>(utf8);
}
// Converts an array of NSString into a vector of UTF-8 std::string, preserving order.
// An element that has no UTF-8 representation (-UTF8String returns NULL) is reported with
// ORT_INVALID_ARGUMENT instead of constructing a std::string from a null pointer.
std::vector<std::string> toStdStringVector(NSArray<NSString*>* strs) {
  std::vector<std::string> result;
  result.reserve(strs.count);
  for (NSString* str in strs) {
    const char* utf8 = [str UTF8String];
    if (utf8 == nullptr) {
      ORT_CXX_API_THROW("Failed to convert NSString to std::string", ORT_INVALID_ARGUMENT);
    }
    result.emplace_back(utf8);
  }
  return result;
}
// Converts a vector of UTF-8 std::string into an array of NSString, preserving order.
// Each element goes through toNSString, which fails with ORT_INVALID_ARGUMENT and the message
// "Failed to convert std::string to NSString" when Foundation cannot build the string.
NSArray<NSString*>* toNSStringNSArray(const std::vector<std::string>& strs) {
  NSMutableArray<NSString*>* converted = [NSMutableArray arrayWithCapacity:strs.size()];
  for (auto it = strs.begin(); it != strs.end(); ++it) {
    [converted addObject:toNSString(*it)];
  }
  return converted;
}
// Wraps every C API OrtValue in `values` in a new ORTValue, preserving order.
// Returns nil as soon as one wrapper cannot be created; `error` is forwarded to the ORTValue
// initializer, and the remaining C API values are released so the caller does not have to.
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& values, NSError** error) {
  NSMutableArray<ORTValue*>* wrapped = [NSMutableArray arrayWithCapacity:values.size()];
  for (auto current = values.begin(); current != values.end(); ++current) {
    ORTValue* ortValue = [[ORTValue alloc] initWithCAPIOrtValue:*current externalTensorData:nil error:error];
    if (ortValue) {
      [wrapped addObject:ortValue];
      continue;
    }
    // Release the value whose wrapper failed and everything after it.
    // NOTE(review): this assumes the failed initializer did not already take ownership of
    // *current; if it did, that value would be released twice - confirm against ORTValue.
    for (auto pending = current; pending != values.end(); ++pending) {
      Ort::GetApi().ReleaseValue(*pending);
    }
    return nil;
  }
  return wrapped;
}
// Collects the C API OrtValue pointer held by each ORTValue in `values`, preserving order.
// The vector is sized up front (as toStdStringVector does) to avoid repeated reallocation.
std::vector<const OrtValue*> getWrappedCAPIOrtValues(NSArray<ORTValue*>* values) {
  std::vector<const OrtValue*> result;
  result.reserve(values.count);
  for (ORTValue* val in values) {
    result.push_back(static_cast<const OrtValue*>([val CXXAPIOrtValue]));
  }
  return result;
}
} // namespace utils
NS_ASSUME_NONNULL_END

View file

@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// this header contains the entire ONNX Runtime training Objective-C API
// the headers below can also be imported individually
#import "onnxruntime.h"
#import "ort_checkpoint.h"
#import "ort_training_session.h"

View file

@ -11,8 +11,8 @@ NS_ASSUME_NONNULL_BEGIN
* An ORT checkpoint is a snapshot of the state of a model at a given point in time.
*
* This class holds the entire training session state that includes model parameters,
* their gradients, optimizer parameters, and user properties. The ORTTrainingSession leverages the
* ORTCheckpointState by accessing and updating the contained training state.
* their gradients, optimizer parameters, and user properties. The `ORTTrainingSession` leverages the
* `ORTCheckpoint` by accessing and updating the contained training state.
*
* Available since v1.16.0.
*

View file

@ -0,0 +1,261 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import <Foundation/Foundation.h>
#include <stdint.h>
NS_ASSUME_NONNULL_BEGIN
@class ORTCheckpoint;
@class ORTEnv;
@class ORTValue;
@class ORTSessionOptions;
/**
* Trainer class that provides methods to train, evaluate and optimize ONNX models.
*
* The training session requires four training artifacts:
* 1. Training onnx model
* 2. Evaluation onnx model (optional)
* 3. Optimizer onnx model
* 4. Checkpoint directory
*
* [onnxruntime-training python utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md)
* can be used to generate above training artifacts.
*
* Available since v1.16.0.
*/
@interface ORTTrainingSession : NSObject
// Plain `init` is not supported; use the designated initializer below.
- (instancetype)init NS_UNAVAILABLE;
/**
* Creates a training session from the training artifacts that can be used to begin or resume training.
*
* The initializer instantiates the training session based on provided env and session options, which can be used to
* begin or resume training from a given checkpoint state. The checkpoint state represents the parameters of training
* session which will be moved to the device specified in the session option if needed.
*
* @param env The `ORTEnv` instance to use for the training session.
* @param sessionOptions The `ORTSessionOptions` to use for the training session.
* @param checkpoint Training states that are used as a starting point for training.
* @param trainModelPath The path to the training onnx model.
* @param evalModelPath The path to the evaluation onnx model. May be nil.
* @param optimizerModelPath The path to the optimizer onnx model used to perform gradient descent. May be nil.
* @param error Optional error information set if an error occurs.
* @return The instance, or nil if an error occurs.
*
*/
- (nullable instancetype)initWithEnv:(ORTEnv*)env
sessionOptions:(ORTSessionOptions*)sessionOptions
checkpoint:(ORTCheckpoint*)checkpoint
trainModelPath:(NSString*)trainModelPath
evalModelPath:(nullable NSString*)evalModelPath
optimizerModelPath:(nullable NSString*)optimizerModelPath
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
/**
* Performs a training step, which is equivalent to a forward and backward propagation in a single step.
*
* The training step computes the outputs of the training model and the gradients of the trainable parameters
* for the given input values. The train step is performed based on the training model that was provided to the training session.
* It is equivalent to running forward and backward propagation in a single step. The computed gradients are stored inside
* the training session state so they can be later consumed by `optimizerStep`. The gradients can be lazily reset by
* calling `lazyResetGrad` method.
*
* @param inputs The input values to the training model.
* @param error Optional error information set if an error occurs.
* @return The output values of the training model, or nil if an error occurs.
*/
- (nullable NSArray<ORTValue*>*)trainStepWithInputValues:(NSArray<ORTValue*>*)inputs
error:(NSError**)error;
/**
* Performs an evaluation step that computes the outputs of the evaluation model for the given inputs.
* The eval step is performed based on the evaluation model that was provided to the training session.
*
* @param inputs The input values to the eval model.
* @param error Optional error information set if an error occurs.
* @return The output values of the eval model, or nil if an error occurs.
*
*/
- (nullable NSArray<ORTValue*>*)evalStepWithInputValues:(NSArray<ORTValue*>*)inputs
error:(NSError**)error;
/**
* Reset the gradients of all trainable parameters to zero lazily.
*
* Calling this method sets the internal state of the training session such that the gradients of the trainable parameters
* in the `ORTCheckpoint` will be scheduled to be reset just before the new gradients are computed on the next
* invocation of the `trainStep` method.
*
* @param error Optional error information set if an error occurs.
* @return YES if the gradients are set to reset successfully, NO otherwise.
*/
- (BOOL)lazyResetGradWithError:(NSError**)error;
/**
* Performs the weight updates for the trainable parameters using the optimizer model. The optimizer step is performed
* based on the optimizer model that was provided to the training session. The updated parameters are stored inside the
* training state so that they can be used by the next `trainStep` method call.
*
* @param error Optional error information set if an error occurs.
* @return YES if the optimizer step was performed successfully, NO otherwise.
*/
- (BOOL)optimizerStepWithError:(NSError**)error;
/**
* Returns the names of the user inputs for the training model that can be associated with
* the `ORTValue` provided to the `trainStep`.
*
* @param error Optional error information set if an error occurs.
* @return The names of the user inputs for the training model, or nil if an error occurs.
*/
- (nullable NSArray<NSString*>*)getTrainInputNamesWithError:(NSError**)error;
/**
* Returns the names of the user inputs for the evaluation model that can be associated with
* the `ORTValue` provided to the `evalStep`.
*
* @param error Optional error information set if an error occurs.
* @return The names of the user inputs for the evaluation model, or nil if an error occurs.
*/
- (nullable NSArray<NSString*>*)getEvalInputNamesWithError:(NSError**)error;
/**
* Returns the names of the user outputs for the training model that can be associated with
* the `ORTValue` returned by the `trainStep`.
*
* @param error Optional error information set if an error occurs.
* @return The names of the user outputs for the training model, or nil if an error occurs.
*/
- (nullable NSArray<NSString*>*)getTrainOutputNamesWithError:(NSError**)error;
/**
* Returns the names of the user outputs for the evaluation model that can be associated with
* the `ORTValue` returned by the `evalStep`.
*
* @param error Optional error information set if an error occurs.
* @return The names of the user outputs for the evaluation model, or nil if an error occurs.
*/
- (nullable NSArray<NSString*>*)getEvalOutputNamesWithError:(NSError**)error;
/**
* Registers a linear learning rate scheduler for the training session.
*
* The scheduler gradually decreases the learning rate from the initial value to zero over the course of the training.
* The decrease is performed by multiplying the current learning rate by a linearly updated factor.
* Before the decrease, the learning rate is gradually increased from zero to the initial value during a warmup phase.
*
* @param warmupStepCount The number of steps to perform the linear warmup.
* @param totalStepCount The total number of steps to perform the linear decay.
* @param initialLr The initial learning rate.
* @param error Optional error information set if an error occurs.
* @return YES if the scheduler was registered successfully, NO otherwise.
*/
- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
totalStepCount:(int64_t)totalStepCount
initialLr:(float)initialLr
error:(NSError**)error;
/**
* Update the learning rate based on the registered learning rate scheduler.
*
* Performs a scheduler step that updates the learning rate that is being used by the training session.
* This function should typically be called before invoking the optimizer step for each round, or as necessary
* to update the learning rate being used by the training session.
*
* @note A valid predefined learning rate scheduler must be first registered to invoke this method.
*
* @param error Optional error information set if an error occurs.
* @return YES if the scheduler step was performed successfully, NO otherwise.
*/
- (BOOL)schedulerStepWithError:(NSError**)error;
/**
* Returns the current learning rate being used by the training session.
*
* @note The return value is a plain float, so the method is annotated with `swift_error(nonnull_error)`:
* Swift treats the call as failed when `error` has been set.
*
* @param error Optional error information set if an error occurs.
* @return The current learning rate or 0.0f if an error occurs.
*/
- (float)getLearningRateWithError:(NSError**)error __attribute__((swift_error(nonnull_error)));
/**
* Sets the learning rate being used by the training session.
*
* The current learning rate is maintained by the training session and can be overwritten by invoking this method
* with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered.
* It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant
* learning rate to be used throughout the training session.
*
* @note It does not set the initial learning rate that may be needed by the predefined learning rate schedulers.
* To set the initial learning rate for learning rate schedulers, use the `registerLinearLRScheduler` method.
*
* @param lr The learning rate to be used by the training session.
* @param error Optional error information set if an error occurs.
* @return YES if the learning rate was set successfully, NO otherwise.
*/
- (BOOL)setLearningRate:(float)lr
error:(NSError**)error;
/**
* Loads the training session model parameters from a contiguous buffer.
*
* @param buffer Contiguous buffer to load the parameters from.
* @param error Optional error information set if an error occurs.
* @return YES if the parameters were loaded successfully, NO otherwise.
*/
- (BOOL)fromBufferWithValue:(ORTValue*)buffer
error:(NSError**)error;
/**
* Returns a contiguous buffer that holds a copy of all training state parameters.
*
* @param onlyTrainable If YES, returns a buffer that holds only the trainable parameters, otherwise returns a buffer
* that holds all the parameters.
* @param error Optional error information set if an error occurs.
* @return A contiguous buffer that holds a copy of the requested training state parameters, or nil if an error occurs.
*/
- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable
error:(NSError**)error;
/**
* Exports the training session model that can be used for inference.
*
* If the training session was provided with an eval model, the training session can generate an inference model if it
* knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the
* inference model's outputs align with the provided outputs. The exported model is saved at the path provided and
* can be used for inferencing with `ORTSession`.
*
* @note The method reloads the eval model from the path provided to the initializer and expects this path to be valid.
*
* @param inferenceModelPath The path at which the serialized inference model is saved.
* @param graphOutputNames The names of the outputs that are needed in the inference model.
* @param error Optional error information set if an error occurs.
* @return YES if the inference model was exported successfully, NO otherwise.
*/
- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath
graphOutputNames:(NSArray<NSString*>*)graphOutputNames
error:(NSError**)error;
@end
#ifdef __cplusplus
extern "C" {
#endif
/**
* Sets the seed used for generating random numbers.
*
* Use this function to make results more reproducible. Note that completely reproducible results are not guaranteed.
* The function is declared with C linkage when this header is compiled as C++.
*
* @param seed Manually set seed to use for random number generation.
*/
void ORTSetSeed(int64_t seed);
#ifdef __cplusplus
}
#endif
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -0,0 +1,227 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import "ort_training_session_internal.h"
#import <vector>
#import <optional>
#import <string>
#import "cxx_api.h"
#import "cxx_utils.h"
#import "error_utils.h"
#import "ort_checkpoint_internal.h"
#import "ort_session_internal.h"
#import "ort_enums_internal.h"
#import "ort_env_internal.h"
#import "ort_value_internal.h"
NS_ASSUME_NONNULL_BEGIN
@implementation ORTTrainingSession {
std::optional<Ort::TrainingSession> _session;
ORTCheckpoint* _checkpoint;
}
// Returns a reference to the wrapped C++ training session.
// The optional is dereferenced without a check; it is assigned in the designated initializer.
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
return *_session;
}
// Designated initializer: builds the C++ Ort::TrainingSession from the given environment, options,
// checkpoint and model paths. Any C++ exception is converted into `error` and nil is returned.
- (nullable instancetype)initWithEnv:(ORTEnv*)env
                      sessionOptions:(ORTSessionOptions*)sessionOptions
                          checkpoint:(ORTCheckpoint*)checkpoint
                      trainModelPath:(NSString*)trainModelPath
                       evalModelPath:(nullable NSString*)evalModelPath
                  optimizerModelPath:(nullable NSString*)optimizerModelPath
                               error:(NSError**)error {
  self = [super init];
  if (self == nil) {
    return nil;
  }

  try {
    // nil paths become empty optionals.
    std::optional<std::string> evalModel = utils::toStdOptionalString(evalModelPath);
    std::optional<std::string> optimizerModel = utils::toStdOptionalString(optimizerModelPath);

    // Keep a reference to the checkpoint alongside the session.
    // NOTE(review): presumably so it outlives the C++ session that was created from it - confirm.
    _checkpoint = checkpoint;
    _session = Ort::TrainingSession{[env CXXAPIOrtEnv],
                                    [sessionOptions CXXAPIOrtSessionOptions],
                                    [checkpoint CXXAPIOrtCheckpoint],
                                    trainModelPath.UTF8String,
                                    evalModel,
                                    optimizerModel};
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Runs one training step through the C training API and wraps the produced outputs.
// Returns nil and sets `error` if any C API call fails or an output cannot be wrapped.
- (nullable NSArray<ORTValue*>*)trainStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                   error:(NSError**)error {
  try {
    std::vector<const OrtValue*> rawInputs = utils::getWrappedCAPIOrtValues(inputs);
    const auto& trainingApi = Ort::GetTrainingApi();

    // Ask the session how many outputs the training model produces so the output array can be sized.
    size_t numOutputs = 0;
    Ort::ThrowOnError(trainingApi.TrainingSessionGetTrainingModelOutputCount(*_session, &numOutputs));
    std::vector<OrtValue*> rawOutputs(numOutputs, nullptr);

    Ort::RunOptions defaultRunOptions;
    Ort::ThrowOnError(trainingApi.TrainStep(*_session,
                                            defaultRunOptions,
                                            rawInputs.size(),
                                            rawInputs.data(),
                                            rawOutputs.size(),
                                            rawOutputs.data()));

    return utils::wrapUnownedCAPIOrtValues(rawOutputs, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Runs one evaluation step through the C training API and wraps the produced outputs.
// Returns nil and sets `error` if any C API call fails or an output cannot be wrapped.
- (nullable NSArray<ORTValue*>*)evalStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                  error:(NSError**)error {
  try {
    // C API pointers of the wrapped input values, in the same order as `inputs`.
    std::vector<const OrtValue*> rawInputs = utils::getWrappedCAPIOrtValues(inputs);
    const auto& trainingApi = Ort::GetTrainingApi();

    // Ask the session how many outputs the eval model produces so the output array can be sized.
    size_t numOutputs = 0;
    Ort::ThrowOnError(trainingApi.TrainingSessionGetEvalModelOutputCount(*_session, &numOutputs));
    std::vector<OrtValue*> rawOutputs(numOutputs, nullptr);

    Ort::RunOptions defaultRunOptions;
    Ort::ThrowOnError(trainingApi.EvalStep(*_session,
                                           defaultRunOptions,
                                           rawInputs.size(),
                                           rawInputs.data(),
                                           rawOutputs.size(),
                                           rawOutputs.data()));

    return utils::wrapUnownedCAPIOrtValues(rawOutputs, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Forwards to Ort::TrainingSession::LazyResetGrad; a C++ exception is reported through `error`.
- (BOOL)lazyResetGradWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.LazyResetGrad();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
// Forwards to Ort::TrainingSession::OptimizerStep; a C++ exception is reported through `error`.
- (BOOL)optimizerStepWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.OptimizerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
// Returns the training model's input names (InputNames called with `true`) as NSStrings.
- (nullable NSArray<NSString*>*)getTrainInputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    return utils::toNSStringNSArray(session.InputNames(true));
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the training model's output names (OutputNames called with `true`) as NSStrings.
- (nullable NSArray<NSString*>*)getTrainOutputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    return utils::toNSStringNSArray(session.OutputNames(true));
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the eval model's input names (InputNames called with `false`) as NSStrings.
- (nullable NSArray<NSString*>*)getEvalInputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    return utils::toNSStringNSArray(session.InputNames(false));
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Returns the eval model's output names (OutputNames called with `false`) as NSStrings.
- (nullable NSArray<NSString*>*)getEvalOutputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    return utils::toNSStringNSArray(session.OutputNames(false));
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Forwards the three scheduler parameters to Ort::TrainingSession::RegisterLinearLRScheduler;
// a C++ exception is reported through `error`.
- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.RegisterLinearLRScheduler(warmupStepCount, totalStepCount, initialLr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
// Forwards to Ort::TrainingSession::SchedulerStep; a C++ exception is reported through `error`.
- (BOOL)schedulerStepWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.SchedulerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
// Returns the learning rate reported by Ort::TrainingSession::GetLearningRate.
// On a C++ exception `error` is set and 0.0f is returned.
- (float)getLearningRateWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    return session.GetLearningRate();
  }
  ORT_OBJC_API_IMPL_CATCH(error, 0.0f)
}
// Forwards `lr` to Ort::TrainingSession::SetLearningRate; a C++ exception is reported through `error`.
- (BOOL)setLearningRate:(float)lr
                  error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.SetLearningRate(lr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
// Passes the C++ value wrapped by `buffer` to Ort::TrainingSession::FromBuffer;
// a C++ exception is reported through `error`.
- (BOOL)fromBufferWithValue:(ORTValue*)buffer
                      error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.FromBuffer([buffer CXXAPIOrtValue]);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
// Obtains the parameter buffer from Ort::TrainingSession::ToBuffer and wraps it in an ORTValue.
// The C++ value gives up its C API handle (release()) before the handle is passed to the wrapper.
- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable
                                      error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    Ort::Value parameterBuffer = session.ToBuffer(onlyTrainable);
    OrtValue* rawBuffer = parameterBuffer.release();
    return [[ORTValue alloc] initWithCAPIOrtValue:rawBuffer
                               externalTensorData:nil
                                            error:error];
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
// Converts the path and output names to C++ strings and forwards them to
// Ort::TrainingSession::ExportModelForInferencing; a C++ exception is reported through `error`.
- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath
                             graphOutputNames:(NSArray<NSString*>*)graphOutputNames
                                        error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.ExportModelForInferencing(utils::toStdString(inferenceModelPath),
                                      utils::toStdStringVector(graphOutputNames));
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
@end
// C-linkage entry point declared in ort_training_session.h; forwards the seed to Ort::SetSeed.
void ORTSetSeed(int64_t seed) {
Ort::SetSeed(seed);
}
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import "ort_training_session.h"
#import "cxx_api.h"
NS_ASSUME_NONNULL_BEGIN
// Internal class extension: exposes the wrapped C++ API object to other implementation files.
@interface ORTTrainingSession ()
// Returns a reference to the underlying Ort::TrainingSession.
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession;
@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -29,4 +29,18 @@ NS_ASSUME_NONNULL_BEGIN
XCTAssertNotNil(error); \
} while (0)
// Asserts that `result` equals `expected` within 1e-3 and that `error` is nil.
// `expected` and `result` are evaluated exactly once and stored in locals, so an argument such as a
// message send is not executed again when the failure message is formatted. `error` is read only
// after `result` has been evaluated, so it reflects whatever that evaluation stored in it.
#define ORTAssertEqualFloatAndNoError(expected, result, error) \
do { \
__typeof__(expected) ortExpectedFloat_ = (expected); \
__typeof__(result) ortActualFloat_ = (result); \
XCTAssertEqualWithAccuracy(ortExpectedFloat_, ortActualFloat_, 1e-3f, @"Expected %f but got %f. Error:%@", ortExpectedFloat_, ortActualFloat_, (error)); \
XCTAssertNil(error); \
} while (0)
// Asserts that two arrays have the same length and that corresponding elements, read with
// -floatValue, are equal within 1e-3.
// Both arguments are evaluated exactly once (they may be function calls), and the element loop
// only covers indices valid in both arrays, so a length mismatch is reported as an assertion
// failure rather than an out-of-range exception.
#define ORTAssertEqualFloatArrays(expected, result) \
do { \
__typeof__(expected) ortExpectedArray_ = (expected); \
__typeof__(result) ortActualArray_ = (result); \
const NSUInteger ortExpectedCount_ = ortExpectedArray_.count; \
const NSUInteger ortActualCount_ = ortActualArray_.count; \
XCTAssertEqual(ortExpectedCount_, ortActualCount_); \
const NSUInteger ortComparableCount_ = ortExpectedCount_ < ortActualCount_ ? ortExpectedCount_ : ortActualCount_; \
for (NSUInteger ortIndex_ = 0; ortIndex_ < ortComparableCount_; ++ortIndex_) { \
XCTAssertEqualWithAccuracy([ortExpectedArray_[ortIndex_] floatValue], [ortActualArray_[ortIndex_] floatValue], 1e-3f); \
} \
} while (0)
NS_ASSUME_NONNULL_END

View file

@ -5,7 +5,11 @@
#import <XCTest/XCTest.h>
#import "ort_checkpoint.h"
#import "ort_training_session.h"
#import "ort_env.h"
#import "ort_session.h"
#import "test/test_utils.h"
#import "test/assertion_utils.h"
NS_ASSUME_NONNULL_BEGIN
@ -27,22 +31,41 @@ NS_ASSUME_NONNULL_BEGIN
ORTAssertNullableResultSuccessful(_ortEnv, err);
}
- (NSString*)getCheckpointPath {
+ (NSString*)getCheckpointPath {
NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"];
return path;
}
- (void)testLoadCheckpoint {
+ (NSString*)getTrainingModelPath {
NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"training_model.onnx"];
return path;
}
- (void)testSaveCheckpoint {
NSError* error = nil;
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
// save checkpoint
NSString* path = [test_utils::createTemporaryDirectory(self) stringByAppendingPathComponent:@"save_checkpoint.ckpt"];
XCTAssertNotNil(path);
BOOL result = [checkpoint saveCheckpointToPath:path withOptimizerState:NO error:&error];
ORTAssertBoolResultSuccessful(result, error);
}
- (void)testInitCheckpoint {
NSError* error = nil;
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
}
- (void)testIntProperty {
NSError* error = nil;
// Load checkpoint
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
// Add property
@ -57,7 +80,7 @@ NS_ASSUME_NONNULL_BEGIN
- (void)testFloatProperty {
NSError* error = nil;
// Load checkpoint
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
// Add property
@ -72,7 +95,7 @@ NS_ASSUME_NONNULL_BEGIN
- (void)testStringProperty {
NSError* error = nil;
// Load checkpoint
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
// Add property

View file

@ -0,0 +1,362 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import <XCTest/XCTest.h>
#import "ort_checkpoint.h"
#import "ort_training_session.h"
#import "ort_env.h"
#import "ort_session.h"
#import "ort_value.h"
#import "test/test_utils.h"
#import "test/assertion_utils.h"
NS_ASSUME_NONNULL_BEGIN
@interface ORTTrainingSessionTest : XCTestCase
@property(readonly, nullable) ORTEnv* ortEnv;
@property(readonly, nullable) ORTCheckpoint* checkpoint;
@property(readonly, nullable) ORTTrainingSession* session;
@end
@implementation ORTTrainingSessionTest
// Creates a fresh environment, checkpoint and training session before each test.
- (void)setUp {
  [super setUp];

  // Abort a test at its first failed assertion.
  self.continueAfterFailure = NO;

  NSError* err = nil;
  _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning error:&err];
  ORTAssertNullableResultSuccessful(_ortEnv, err);

  NSString* checkpointPath = [ORTTrainingSessionTest getFilePathFromName:@"checkpoint.ckpt"];
  _checkpoint = [[ORTCheckpoint alloc] initWithPath:checkpointPath error:&err];
  ORTAssertNullableResultSuccessful(_checkpoint, err);

  _session = [self makeTrainingSessionWithCheckpoint:_checkpoint];
}
// Returns the path of `name` inside the resource directory of the bundle containing this test class.
+ (NSString*)getFilePathFromName:(NSString*)name {
  NSString* resourceDirectory = [[NSBundle bundleForClass:[ORTTrainingSessionTest class]] resourcePath];
  return [resourceDirectory stringByAppendingPathComponent:name];
}
// Reads a UTF-8 text file and parses one line of it into a buffer of raw floats.
// The data line is the first line of the file, or the second when `skipHeader` is YES. It is split
// on ',', '[', ']' and ' '; every non-empty token is converted with -floatValue and appended.
// A file that has no data line fails an assertion and yields an empty buffer instead of raising
// an out-of-range exception.
+ (NSMutableData*)loadTensorDataFromFile:(NSString*)filePath skipHeader:(BOOL)skipHeader {
  NSError* error = nil;
  NSString* fileContents = [NSString stringWithContentsOfFile:filePath
                                                     encoding:NSUTF8StringEncoding
                                                        error:&error];
  ORTAssertNullableResultSuccessful(fileContents, error);

  NSArray<NSString*>* lines = [fileContents componentsSeparatedByCharactersInSet:[NSCharacterSet newlineCharacterSet]];
  const NSUInteger dataLineIndex = skipHeader ? 1 : 0;
  XCTAssertTrue(lines.count > dataLineIndex, @"No data line found in %@", filePath);
  NSString* dataLine = lines.count > dataLineIndex ? lines[dataLineIndex] : @"";

  NSArray<NSString*>* dataArray = [dataLine componentsSeparatedByCharactersInSet:
                                                [NSCharacterSet characterSetWithCharactersInString:@",[] "]];
  NSMutableData* tensorData = [NSMutableData data];
  for (NSString* str in dataArray) {
    // Adjacent separators produce empty tokens; skip them.
    if (str.length > 0) {
      float value = [str floatValue];
      [tensorData appendBytes:&value length:sizeof(float)];
    }
  }
  return tensorData;
}
// Builds a training session for `checkpoint` from the bundled training, eval and optimizer models,
// using default session options and this test's environment.
- (ORTTrainingSession*)makeTrainingSessionWithCheckpoint:(ORTCheckpoint*)checkpoint {
  NSError* error = nil;
  ORTSessionOptions* sessionOptions = [[ORTSessionOptions alloc] initWithError:&error];
  ORTAssertNullableResultSuccessful(sessionOptions, error);

  NSString* trainModel = [ORTTrainingSessionTest getFilePathFromName:@"training_model.onnx"];
  NSString* evalModel = [ORTTrainingSessionTest getFilePathFromName:@"eval_model.onnx"];
  NSString* optimizerModel = [ORTTrainingSessionTest getFilePathFromName:@"adamw.onnx"];

  ORTTrainingSession* session = [[ORTTrainingSession alloc] initWithEnv:self.ortEnv
                                                         sessionOptions:sessionOptions
                                                             checkpoint:checkpoint
                                                         trainModelPath:trainModel
                                                          evalModelPath:evalModel
                                                     optimizerModelPath:optimizerModel
                                                                  error:&error];
  ORTAssertNullableResultSuccessful(session, error);
  return session;
}
// Verifies that the session built in setUp reports the expected training model input and output names.
- (void)testInitTrainingSession {
  NSError* error = nil;
  ORTTrainingSession* trainingSession = self.session;

  // The training model must expose "input-0" among its inputs.
  NSArray<NSString*>* inputNames = [trainingSession getTrainInputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(inputNames, error);
  XCTAssertTrue(inputNames.count > 0);
  XCTAssertTrue([inputNames containsObject:@"input-0"]);

  // The training model must expose the loss among its outputs.
  NSArray<NSString*>* outputNames = [trainingSession getTrainOutputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(outputNames, error);
  XCTAssertTrue(outputNames.count > 0);
  XCTAssertTrue([outputNames containsObject:@"onnx::loss::21273"]);
}
// Verifies that the session built in setUp reports the expected eval model input and output names.
- (void)testInitTrainingSessionWithEval {
  NSError* error = nil;
  ORTTrainingSession* trainingSession = self.session;

  // The eval model must expose "input-0" among its inputs.
  NSArray<NSString*>* inputNames = [trainingSession getEvalInputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(inputNames, error);
  XCTAssertTrue(inputNames.count > 0);
  XCTAssertTrue([inputNames containsObject:@"input-0"]);

  // The eval model must expose the loss among its outputs.
  NSArray<NSString*>* outputNames = [trainingSession getEvalOutputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(outputNames, error);
  XCTAssertTrue(outputNames.count > 0);
  XCTAssertTrue([outputNames containsObject:@"onnx::loss::21273"]);
}
// Runs a train step, lazily resets the gradients, runs a second train step, and checks that the
// first output of the second step is a float tensor matching the values stored in "loss_1.out".
- (void)runTrainStep {
  NSError* error = nil;

  // Reference loss and model input, both read from bundled text files with a header line.
  NSString* expectedLossPath = [ORTTrainingSessionTest getFilePathFromName:@"loss_1.out"];
  NSMutableData* expectedOutput = [ORTTrainingSessionTest loadTensorDataFromFile:expectedLossPath skipHeader:YES];
  NSString* inputPath = [ORTTrainingSessionTest getFilePathFromName:@"input-0.in"];
  NSMutableData* input = [ORTTrainingSessionTest loadTensorDataFromFile:inputPath skipHeader:YES];

  int32_t labels[] = {1, 1};
  NSMutableData* labelData = [NSMutableData dataWithBytes:labels length:sizeof(labels)];

  // Input tensor of shape [2, 784] followed by a label tensor of shape [2].
  ORTValue* inputTensor = [[ORTValue alloc] initWithTensorData:input
                                                   elementType:ORTTensorElementDataTypeFloat
                                                         shape:@[ @2, @784 ]
                                                         error:&error];
  ORTAssertNullableResultSuccessful(inputTensor, error);

  ORTValue* labelTensor = [[ORTValue alloc] initWithTensorData:labelData
                                                   elementType:ORTTensorElementDataTypeInt32
                                                         shape:@[ @2 ]
                                                         error:&error];
  ORTAssertNullableResultSuccessful(labelTensor, error);

  NSMutableArray<ORTValue*>* inputValues = [NSMutableArray arrayWithCapacity:2];
  [inputValues addObject:inputTensor];
  [inputValues addObject:labelTensor];

  // First step.
  NSArray<ORTValue*>* outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);

  BOOL result = [self.session lazyResetGradWithError:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Second step, after the lazy gradient reset has been requested.
  outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);

  // The first output must be a float tensor.
  ORTValue* outputValue = outputs[0];
  ORTValueTypeInfo* typeInfo = [outputValue typeInfoWithError:&error];
  ORTAssertNullableResultSuccessful(typeInfo, error);
  XCTAssertEqual(typeInfo.type, ORTValueTypeTensor);
  XCTAssertNotNil(typeInfo.tensorTypeAndShapeInfo);

  ORTTensorTypeAndShapeInfo* tensorInfo = [outputValue tensorTypeAndShapeInfoWithError:&error];
  ORTAssertNullableResultSuccessful(tensorInfo, error);
  XCTAssertEqual(tensorInfo.elementType, ORTTensorElementDataTypeFloat);

  // Compare the produced values with the reference file.
  NSMutableData* tensorData = [outputValue tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(tensorData, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(tensorData),
                            test_utils::getFloatArrayFromData(expectedOutput));
}
// Test entry point for the train step checks; all assertions live in -runTrainStep.
- (void)testTrainStepOutput {
  [self runTrainStep];
}
// Checks the loss reported by train steps before and after an optimizer step.
//
// Sequence: train step -> lazy gradient reset -> train step -> optimizer step -> train step.
// The losses of the first two train steps are compared with the values from "loss_1.out", the loss
// of the train step that follows the optimizer step with the values from "loss_2.out".
- (void)testOptimizerStep {
  // load input and expected output
  NSError* error = nil;
  NSMutableData* expectedOutput1 = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest
                                                                                      getFilePathFromName:@"loss_1.out"]
                                                                       skipHeader:YES];
  NSMutableData* expectedOutput2 = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest
                                                                                      getFilePathFromName:@"loss_2.out"]
                                                                       skipHeader:YES];
  NSMutableData* input = [ORTTrainingSessionTest loadTensorDataFromFile:[ORTTrainingSessionTest
                                                                            getFilePathFromName:@"input-0.in"]
                                                             skipHeader:YES];
  int32_t labels[] = {1, 1};

  // create ORTValue array for input and labels
  NSMutableArray<ORTValue*>* inputValues = [NSMutableArray array];
  ORTValue* inputTensor = [[ORTValue alloc] initWithTensorData:input
                                                   elementType:ORTTensorElementDataTypeFloat
                                                         shape:@[ @2, @784 ]
                                                         error:&error];
  ORTAssertNullableResultSuccessful(inputTensor, error);
  [inputValues addObject:inputTensor];
  ORTValue* labelTensor = [[ORTValue alloc] initWithTensorData:[NSMutableData dataWithBytes:labels
                                                                                     length:sizeof(labels)]
                                                   elementType:ORTTensorElementDataTypeInt32
                                                         shape:@[ @2 ]
                                                         error:&error];
  ORTAssertNullableResultSuccessful(labelTensor, error);
  [inputValues addObject:labelTensor];

  // Each train step below asserts that there is at least one output and reads the loss through
  // firstObject rather than outputs[0]: an empty result is then reported by the assertions instead
  // of raising NSRangeException in the middle of the test.

  // run train step and check loss
  NSArray<ORTValue*>* outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);
  NSMutableData* loss = [outputs.firstObject tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(loss, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss),
                            test_utils::getFloatArrayFromData(expectedOutput1));

  // reset gradients and run train step again; no optimizer step yet, so the first loss is expected
  BOOL result = [self.session lazyResetGradWithError:&error];
  ORTAssertBoolResultSuccessful(result, error);
  outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);
  loss = [outputs.firstObject tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(loss, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss),
                            test_utils::getFloatArrayFromData(expectedOutput1));

  // run optimizer step, then the next train step is expected to report the second loss
  result = [self.session optimizerStepWithError:&error];
  ORTAssertBoolResultSuccessful(result, error);
  outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);
  loss = [outputs.firstObject tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(loss, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss),
                            test_utils::getFloatArrayFromData(expectedOutput2));
}
// Sets a learning rate on the session and checks that the same value is read back.
- (void)testSetLearningRate {
  NSError* error = nil;
  const float expectedLearningRate = 0.1f;

  BOOL didSet = [self.session setLearningRate:expectedLearningRate error:&error];
  ORTAssertBoolResultSuccessful(didSet, error);

  float reportedLearningRate = [self.session getLearningRateWithError:&error];
  ORTAssertEqualFloatAndNoError(expectedLearningRate, reportedLearningRate, error);
}
// Registers a linear learning rate scheduler (2 warmup steps, 4 total steps, initial rate 0.1) and
// checks the learning rate that the session reports after each optimizer step + scheduler step.
- (void)testLinearLRScheduler {
  NSError* error = nil;
  const float initialLearningRate = 0.1f;

  BOOL succeeded = [self.session registerLinearLRSchedulerWithWarmupStepCount:2
                                                                totalStepCount:4
                                                                     initialLr:initialLearningRate
                                                                         error:&error];
  ORTAssertBoolResultSuccessful(succeeded, error);

  [self runTrainStep];

  // learning rate expected after the 1st, 2nd, 3rd and 4th scheduler step, in that order
  const float expectedLearningRates[] = {0.05f, 0.1f, 0.05f, 0.0f};

  for (const float expectedLearningRate : expectedLearningRates) {
    succeeded = [self.session optimizerStepWithError:&error];
    ORTAssertBoolResultSuccessful(succeeded, error);

    succeeded = [self.session schedulerStepWithError:&error];
    ORTAssertBoolResultSuccessful(succeeded, error);

    ORTAssertEqualFloatAndNoError(expectedLearningRate, [self.session getLearningRateWithError:&error], error);
  }
}
// Exports an inference model with the graph output "output-0" into the temporary test directory
// and checks that the model file exists afterwards.
- (void)testExportModelForInference {
  NSError* error = nil;

  NSString* temporaryDirectory = test_utils::createTemporaryDirectory(self);
  NSString* inferenceModelPath = [temporaryDirectory stringByAppendingPathComponent:@"inference_model.onnx"];
  XCTAssertNotNil(inferenceModelPath);

  NSArray<NSString*>* graphOutputNames = @[ @"output-0" ];
  BOOL exported = [self.session exportModelForInferenceWithOutputPath:inferenceModelPath
                                                     graphOutputNames:graphOutputNames
                                                                error:&error];
  ORTAssertBoolResultSuccessful(exported, error);

  NSFileManager* fileManager = [NSFileManager defaultManager];
  XCTAssertTrue([fileManager fileExistsAtPath:inferenceModelPath]);

  // best-effort removal of the exported model; the outcome is deliberately not checked
  [self addTeardownBlock:^{
    [[NSFileManager defaultManager] removeItemAtPath:inferenceModelPath error:nil];
  }];
}
// Checks that -toBufferWithTrainable:error: (trainable = YES) yields a tensor-typed ORTValue.
- (void)testToBuffer {
  NSError* error = nil;

  ORTValue* parameterBuffer = [self.session toBufferWithTrainable:YES error:&error];
  ORTAssertNullableResultSuccessful(parameterBuffer, error);

  ORTValueTypeInfo* bufferTypeInfo = [parameterBuffer typeInfoWithError:&error];
  ORTAssertNullableResultSuccessful(bufferTypeInfo, error);
  XCTAssertEqual(bufferTypeInfo.type, ORTValueTypeTensor);
  XCTAssertNotNil(bufferTypeInfo.tensorTypeAndShapeInfo);
}
// Round trip: the value obtained from -toBufferWithTrainable:error: must be accepted by
// -fromBufferWithValue:error:.
- (void)testFromBuffer {
  NSError* error = nil;

  ORTValue* parameterBuffer = [self.session toBufferWithTrainable:YES error:&error];
  ORTAssertNullableResultSuccessful(parameterBuffer, error);

  BOOL didLoad = [self.session fromBufferWithValue:parameterBuffer error:&error];
  ORTAssertBoolResultSuccessful(didLoad, error);
}
// Clears the session, checkpoint and environment ivars, then runs the superclass teardown.
- (void)tearDown {
// NOTE(review): the session (created from the checkpoint) is cleared first, the env last.
// Presumably objects are dropped before the ones they depend on - confirm against -setUp before
// reordering.
_session = nil;
_checkpoint = nil;
_ortEnv = nil;
[super tearDown];
}
@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -0,0 +1,29 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import <XCTest/XCTest.h>
#import "ort_training_session.h"
NS_ASSUME_NONNULL_BEGIN
// Test case for the training utility functions declared in ort_training_session.h.
@interface ORTTrainingUtilsTest : XCTestCase
@end

@implementation ORTTrainingUtilsTest

- (void)setUp {
  [super setUp];

  // stop a test at its first failed assertion
  self.continueAfterFailure = NO;
}

// Smoke test: only verifies that ORTSetSeed can be called; nothing is asserted about its effect.
- (void)testSetSeed {
  ORTSetSeed(2718);
}

@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import <Foundation/Foundation.h>
#import <XCTest/XCTest.h>
#import "ort_session.h"
#import "ort_env.h"
#import "ort_value.h"
NS_ASSUME_NONNULL_BEGIN
// Helpers shared by the Objective-C test cases.
namespace test_utils {
// Creates the directory "ort-objective-c-test" inside NSTemporaryDirectory() and registers a
// teardown block on `testCase` that removes the directory again.
// Returns the path of the directory. The result is declared nullable, so callers should check it
// before relying on the path.
NSString* _Nullable createTemporaryDirectory(XCTestCase* testCase);
// Reads `data` as consecutive `float` values and returns them, in order, boxed as NSNumbers.
// Trailing bytes that do not make up a whole float are ignored.
NSArray<NSNumber*>* getFloatArrayFromData(NSData* data);
} // namespace test_utils
NS_ASSUME_NONNULL_END

View file

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#import "test_utils.h"
NS_ASSUME_NONNULL_BEGIN
namespace test_utils {
// Creates the directory "ort-objective-c-test" inside NSTemporaryDirectory() and registers a
// teardown block on `testCase` that removes it again when the test finishes.
//
// Returns the path of the directory, or nil if it could not be created. A test failure is recorded
// in that case, and no teardown block is registered because there is nothing to remove.
NSString* _Nullable createTemporaryDirectory(XCTestCase* testCase) {
  NSString* temporaryDirectory = NSTemporaryDirectory();
  NSString* directoryPath = [temporaryDirectory stringByAppendingPathComponent:@"ort-objective-c-test"];

  // Success is judged by the BOOL return value: by Cocoa convention the NSError out-parameter is
  // only meaningful when the call reports failure.
  NSError* error = nil;
  BOOL created = [[NSFileManager defaultManager] createDirectoryAtPath:directoryPath
                                            withIntermediateDirectories:YES
                                                             attributes:nil
                                                                  error:&error];
  XCTAssertTrue(created, @"Error creating temporary directory: %@", error.localizedDescription);
  if (!created) {
    // matches the nullable return type declared in test_utils.h
    return nil;
  }

  // add teardown block to delete the temporary directory
  [testCase addTeardownBlock:^{
    NSError* error = nil;
    BOOL removed = [[NSFileManager defaultManager] removeItemAtPath:directoryPath error:&error];
    XCTAssertTrue(removed, @"Error removing temporary directory: %@", error.localizedDescription);
  }];

  return directoryPath;
}
// Reads `data` as consecutive `float` values and returns them, in order, boxed as NSNumbers.
// Trailing bytes that do not make up a whole float are ignored.
NSArray<NSNumber*>* getFloatArrayFromData(NSData* data) {
  const NSUInteger floatCount = data.length / sizeof(float);
  NSMutableArray<NSNumber*>* floats = [NSMutableArray arrayWithCapacity:floatCount];

  for (NSUInteger index = 0; index < floatCount; ++index) {
    // copy the bytes of the float at `index` out of the buffer
    float element;
    [data getBytes:&element range:NSMakeRange(index * sizeof(float), sizeof(float))];
    [floats addObject:@(element)];
  }

  return floats;
}
} // namespace test_utils
NS_ASSUME_NONNULL_END

View file

@ -194,7 +194,6 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
* \param[in] input_values The user inputs to the training model.
* \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
std::vector<Value> TrainStep(const std::vector<Value>& input_values);