mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-15 20:50:42 +00:00
Objective C Training API: TrainingSession (#16374)
### Description - Implement Objective-C binding for `ORTTrainingSession` - Add `ORTUtils` utility class to handle conversion between C++ and Objective-C types - Add test case for saving checkpoint - Add unit test cases for `ORTTrainingSession` ### Motivation and Context This PR is part of implementing Objective-C bindings for training API. It implements objective-c binding for training session. The objective-C API closely resembles the C++ API. --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
parent
6dd4e4801a
commit
960e320dff
15 changed files with 1150 additions and 12 deletions
|
|
@ -42,11 +42,14 @@ file(GLOB onnxruntime_objc_srcs CONFIGURE_DEPENDS
|
|||
|
||||
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
|
||||
list(REMOVE_ITEM onnxruntime_objc_headers
|
||||
"${OBJC_ROOT}/include/ort_checkpoint.h")
|
||||
"${OBJC_ROOT}/include/ort_checkpoint.h"
|
||||
"${OBJC_ROOT}/include/ort_training_session.h")
|
||||
|
||||
list(REMOVE_ITEM onnxruntime_objc_srcs
|
||||
"${OBJC_ROOT}/ort_checkpoint_internal.h"
|
||||
"${OBJC_ROOT}/ort_checkpoint.mm")
|
||||
"${OBJC_ROOT}/ort_checkpoint.mm"
|
||||
"${OBJC_ROOT}/ort_training_session_internal.h"
|
||||
"${OBJC_ROOT}/ort_training_session.mm")
|
||||
endif()
|
||||
|
||||
|
||||
|
|
@ -124,7 +127,9 @@ if(onnxruntime_BUILD_UNIT_TESTS)
|
|||
|
||||
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
|
||||
list(REMOVE_ITEM onnxruntime_objc_test_srcs
|
||||
"${OBJC_ROOT}/test/ort_checkpoint_test.mm")
|
||||
"${OBJC_ROOT}/test/ort_checkpoint_test.mm"
|
||||
"${OBJC_ROOT}/test/ort_training_session_test.mm"
|
||||
"${OBJC_ROOT}/test/ort_training_utils_test.mm")
|
||||
|
||||
endif()
|
||||
|
||||
|
|
|
|||
32
objectivec/cxx_utils.h
Normal file
32
objectivec/cxx_utils.h
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#import "cxx_api.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
@class ORTValue;
|
||||
|
||||
namespace utils {
|
||||
|
||||
// Converts a UTF-8 std::string to an NSString.
// Raises via ORT_CXX_API_THROW (ORT_INVALID_ARGUMENT) if the NSString cannot be created.
NSString* toNSString(const std::string& str);
|
||||
// Returns nil when `str` holds no value; otherwise converts the contained string with toNSString().
NSString* _Nullable toNullableNSString(const std::optional<std::string>& str);
|
||||
|
||||
// Returns a std::string holding a copy of the UTF-8 representation of `str`.
std::string toStdString(NSString* str);
|
||||
// Returns std::nullopt when `str` is nil; otherwise the UTF-8 representation of `str`.
std::optional<std::string> toStdOptionalString(NSString* _Nullable str);
|
||||
|
||||
// Converts every element of `strs` to a UTF-8 std::string, preserving order.
std::vector<std::string> toStdStringVector(NSArray<NSString*>* strs);
|
||||
// Converts every element of `strs` to an NSString, preserving order.
// Raises via ORT_CXX_API_THROW (ORT_INVALID_ARGUMENT) if any element cannot be converted.
NSArray<NSString*>* toNSStringNSArray(const std::vector<std::string>& strs);
|
||||
|
||||
// Wraps each C API OrtValue in a newly created ORTValue (without external tensor data), preserving order.
// Returns nil as soon as one wrapper fails to initialize; `error` is forwarded to the ORTValue initializer
// and the C API values that were not wrapped are released.
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& values, NSError** error);
|
||||
|
||||
// Returns the C API OrtValue pointer behind each ORTValue in `values`, preserving order.
// NOTE(review): the pointers are obtained from the wrappers and presumably remain owned by them — confirm.
std::vector<const OrtValue*> getWrappedCAPIOrtValues(NSArray<ORTValue*>* values);
|
||||
|
||||
} // namespace utils
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
93
objectivec/cxx_utils.mm
Normal file
93
objectivec/cxx_utils.mm
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import "cxx_utils.h"
|
||||
|
||||
#import <vector>
|
||||
#import <optional>
|
||||
#import <string>
|
||||
|
||||
#import "error_utils.h"
|
||||
|
||||
#import "ort_value_internal.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
namespace utils {
|
||||
|
||||
// Converts a UTF-8 encoded std::string into an NSString.
// Raises via ORT_CXX_API_THROW (ORT_INVALID_ARGUMENT) when +stringWithUTF8String: yields nil.
NSString* toNSString(const std::string& str) {
  if (NSString* converted = [NSString stringWithUTF8String:str.c_str()]) {
    return converted;
  }

  ORT_CXX_API_THROW("Failed to convert std::string to NSString", ORT_INVALID_ARGUMENT);
}
|
||||
|
||||
// Maps an optional std::string to a nullable NSString: an empty optional becomes nil,
// an engaged one is converted with toNSString() (which may raise on conversion failure).
NSString* _Nullable toNullableNSString(const std::optional<std::string>& str) {
  if (!str.has_value()) {
    return nil;
  }

  NSString* converted = toNSString(str.value());
  return converted;
}
|
||||
|
||||
// Copies the UTF-8 representation of `str` into a std::string.
// -UTF8String yields NULL when messaged on nil or when no UTF-8 representation is available;
// constructing a std::string from a null pointer is undefined behaviour, so that case is
// reported via ORT_CXX_API_THROW (ORT_INVALID_ARGUMENT) instead.
std::string toStdString(NSString* str) {
  const char* utf8 = [str UTF8String];
  if (utf8 == nullptr) {
    ORT_CXX_API_THROW("Failed to convert NSString to std::string", ORT_INVALID_ARGUMENT);
  }

  return std::string(utf8);
}
|
||||
|
||||
// Maps a nullable NSString to an optional std::string: nil becomes std::nullopt, anything else
// is copied from its UTF-8 representation.
// If -UTF8String returns NULL for a non-nil string, the failure is reported via
// ORT_CXX_API_THROW (ORT_INVALID_ARGUMENT) rather than building a std::string from a null pointer.
std::optional<std::string> toStdOptionalString(NSString* _Nullable str) {
  if (str == nil) {
    return std::nullopt;
  }

  const char* utf8 = [str UTF8String];
  if (utf8 == nullptr) {
    ORT_CXX_API_THROW("Failed to convert NSString to std::string", ORT_INVALID_ARGUMENT);
  }

  return std::optional<std::string>(utf8);
}
|
||||
|
||||
// Converts an array of NSString to a vector of UTF-8 std::string, preserving order.
// An element whose -UTF8String is NULL is reported via ORT_CXX_API_THROW (ORT_INVALID_ARGUMENT)
// instead of being passed to the std::string constructor as a null pointer.
std::vector<std::string> toStdStringVector(NSArray<NSString*>* strs) {
  std::vector<std::string> result;
  result.reserve(strs.count);

  for (NSString* str in strs) {
    const char* utf8 = [str UTF8String];
    if (utf8 == nullptr) {
      ORT_CXX_API_THROW("Failed to convert NSString to std::string", ORT_INVALID_ARGUMENT);
    }
    result.emplace_back(utf8);
  }

  return result;
}
|
||||
|
||||
// Converts a vector of UTF-8 std::string to an array of NSString, preserving order.
// Each element goes through toNSString(), so a failed conversion raises via
// ORT_CXX_API_THROW (ORT_INVALID_ARGUMENT) with the same message as the single-string helper.
NSArray<NSString*>* toNSStringNSArray(const std::vector<std::string>& strs) {
  NSMutableArray<NSString*>* result = [NSMutableArray arrayWithCapacity:strs.size()];

  for (const std::string& str : strs) {
    [result addObject:toNSString(str)];
  }

  return result;
}
|
||||
|
||||
// Wraps every C API OrtValue in `values` in a new ORTValue (no external tensor data) and returns
// the wrappers in the same order.
// If a wrapper cannot be created, the values from the failing index onwards are released through
// the C API and nil is returned; `error` is whatever the ORTValue initializer reported.
NSArray<ORTValue*>* _Nullable wrapUnownedCAPIOrtValues(const std::vector<OrtValue*>& values, NSError** error) {
  const size_t valueCount = values.size();
  NSMutableArray<ORTValue*>* wrapped = [NSMutableArray arrayWithCapacity:valueCount];

  for (size_t index = 0; index < valueCount; ++index) {
    ORTValue* wrapper = [[ORTValue alloc] initWithCAPIOrtValue:values[index]
                                            externalTensorData:nil
                                                         error:error];
    if (wrapper) {
      [wrapped addObject:wrapper];
      continue;
    }

    // Release everything that has not ended up inside an ORTValue.
    // NOTE(review): this includes values[index], the value whose wrapper failed. That is only
    // correct if the failed initializer did not already take ownership of it — confirm against
    // -initWithCAPIOrtValue:externalTensorData:error:.
    size_t pending = index;
    while (pending < valueCount) {
      Ort::GetApi().ReleaseValue(values[pending]);
      ++pending;
    }
    return nil;
  }

  return wrapped;
}
|
||||
|
||||
// Collects the C API OrtValue pointer behind each ORTValue in `values`, preserving order.
// NOTE(review): the pointers come from the wrappers' CXXAPIOrtValue and presumably stay valid only
// while the corresponding ORTValue objects are alive — confirm.
std::vector<const OrtValue*> getWrappedCAPIOrtValues(NSArray<ORTValue*>* values) {
  std::vector<const OrtValue*> result;
  // The final size is known up front; avoid repeated reallocation while appending.
  result.reserve(values.count);

  for (ORTValue* val in values) {
    result.push_back(static_cast<const OrtValue*>([val CXXAPIOrtValue]));
  }

  return result;
}
|
||||
|
||||
} // namespace utils
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
9
objectivec/include/onnxruntime_training.h
Normal file
9
objectivec/include/onnxruntime_training.h
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// this header contains the entire ONNX Runtime training Objective-C API
|
||||
// the headers below can also be imported individually
|
||||
|
||||
#import "onnxruntime.h"
|
||||
#import "ort_checkpoint.h"
|
||||
#import "ort_training_session.h"
|
||||
|
|
@ -11,8 +11,8 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
* An ORT checkpoint is a snapshot of the state of a model at a given point in time.
|
||||
*
|
||||
* This class holds the entire training session state that includes model parameters,
|
||||
* their gradients, optimizer parameters, and user properties. The ORTTrainingSession leverages the
|
||||
* ORTCheckpointState by accessing and updating the contained training state.
|
||||
* their gradients, optimizer parameters, and user properties. The `ORTTrainingSession` leverages the
|
||||
* `ORTCheckpoint` by accessing and updating the contained training state.
|
||||
*
|
||||
* Available since v1.16.0.
|
||||
*
|
||||
|
|
|
|||
261
objectivec/include/ort_training_session.h
Normal file
261
objectivec/include/ort_training_session.h
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import <Foundation/Foundation.h>
|
||||
#include <stdint.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@class ORTCheckpoint;
|
||||
@class ORTEnv;
|
||||
@class ORTValue;
|
||||
@class ORTSessionOptions;
|
||||
|
||||
/**
|
||||
* Trainer class that provides methods to train, evaluate and optimize ONNX models.
|
||||
*
|
||||
* The training session requires four training artifacts:
|
||||
* 1. Training onnx model
|
||||
* 2. Evaluation onnx model (optional)
|
||||
* 3. Optimizer onnx model
|
||||
* 4. Checkpoint directory
|
||||
*
|
||||
* [onnxruntime-training python utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md)
|
||||
* can be used to generate above training artifacts.
|
||||
*
|
||||
* Available since v1.16.0.
|
||||
*/
|
||||
@interface ORTTrainingSession : NSObject
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
/**
|
||||
* Creates a training session from the training artifacts that can be used to begin or resume training.
|
||||
*
|
||||
* The initializer instantiates the training session based on provided env and session options, which can be used to
|
||||
* begin or resume training from a given checkpoint state. The checkpoint state represents the parameters of training
|
||||
* session which will be moved to the device specified in the session option if needed.
|
||||
*
|
||||
* @param env The `ORTEnv` instance to use for the training session.
|
||||
* @param sessionOptions The `ORTSessionOptions` to use for the training session.
|
||||
* @param checkpoint Training states that are used as a starting point for training.
|
||||
* @param trainModelPath The path to the training onnx model.
|
||||
* @param evalModelPath The path to the evaluation onnx model. Pass nil if no evaluation model is provided.
|
||||
* @param optimizerModelPath The path to the optimizer onnx model used to perform gradient descent. Pass nil if no optimizer model is provided.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The instance, or nil if an error occurs.
|
||||
*
|
||||
*/
|
||||
- (nullable instancetype)initWithEnv:(ORTEnv*)env
|
||||
sessionOptions:(ORTSessionOptions*)sessionOptions
|
||||
checkpoint:(ORTCheckpoint*)checkpoint
|
||||
trainModelPath:(NSString*)trainModelPath
|
||||
evalModelPath:(nullable NSString*)evalModelPath
|
||||
optimizerModelPath:(nullable NSString*)optimizerModelPath
|
||||
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
* Performs a training step, which is equivalent to a forward and backward propagation in a single step.
|
||||
*
|
||||
* The training step computes the outputs of the training model and the gradients of the trainable parameters
|
||||
* for the given input values. The train step is performed based on the training model that was provided to the training session.
|
||||
* It is equivalent to running forward and backward propagation in a single step. The computed gradients are stored inside
|
||||
* the training session state so they can be later consumed by `optimizerStep`. The gradients can be lazily reset by
|
||||
* calling `lazyResetGrad` method.
|
||||
*
|
||||
* @param inputs The input values to the training model.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The output values of the training model.
|
||||
*/
|
||||
- (nullable NSArray<ORTValue*>*)trainStepWithInputValues:(NSArray<ORTValue*>*)inputs
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Performs an evaluation step that computes the outputs of the evaluation model for the given inputs.
|
||||
* The eval step is performed based on the evaluation model that was provided to the training session.
|
||||
*
|
||||
* @param inputs The input values to the eval model.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The output values of the eval model.
|
||||
*
|
||||
*/
|
||||
- (nullable NSArray<ORTValue*>*)evalStepWithInputValues:(NSArray<ORTValue*>*)inputs
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Reset the gradients of all trainable parameters to zero lazily.
|
||||
*
|
||||
* Calling this method sets the internal state of the training session such that the gradients of the trainable parameters
|
||||
* in the ORTCheckpoint will be scheduled to be reset just before the new gradients are computed on the next
|
||||
* invocation of the `trainStep` method.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return YES if the gradients are set to reset successfully, NO otherwise.
|
||||
*/
|
||||
- (BOOL)lazyResetGradWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Performs the weight updates for the trainable parameters using the optimizer model. The optimizer step is performed
|
||||
* based on the optimizer model that was provided to the training session. The updated parameters are stored inside the
|
||||
* training state so that they can be used by the next `trainStep` method call.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return YES if the optimizer step was performed successfully, NO otherwise.
|
||||
*/
|
||||
- (BOOL)optimizerStepWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Returns the names of the user inputs for the training model that can be associated with
|
||||
* the `ORTValue` provided to the `trainStep`.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The names of the user inputs for the training model.
|
||||
*/
|
||||
- (nullable NSArray<NSString*>*)getTrainInputNamesWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Returns the names of the user inputs for the evaluation model that can be associated with
|
||||
* the `ORTValue` provided to the `evalStep`.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The names of the user inputs for the evaluation model.
|
||||
*/
|
||||
- (nullable NSArray<NSString*>*)getEvalInputNamesWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Returns the names of the user outputs for the training model that can be associated with
|
||||
* the `ORTValue` returned by the `trainStep`.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The names of the user outputs for the training model.
|
||||
*/
|
||||
- (nullable NSArray<NSString*>*)getTrainOutputNamesWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Returns the names of the user outputs for the evaluation model that can be associated with
|
||||
* the `ORTValue` returned by the `evalStep`.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The names of the user outputs for the evaluation model.
|
||||
*/
|
||||
- (nullable NSArray<NSString*>*)getEvalOutputNamesWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Registers a linear learning rate scheduler for the training session.
|
||||
*
|
||||
* The scheduler gradually decreases the learning rate from the initial value to zero over the course of the training.
|
||||
* The decrease is performed by multiplying the current learning rate by a linearly updated factor.
|
||||
* Before the decrease, the learning rate is gradually increased from zero to the initial value during a warmup phase.
|
||||
*
|
||||
* @param warmupStepCount The number of steps to perform the linear warmup.
|
||||
* @param totalStepCount The total number of steps to perform the linear decay.
|
||||
* @param initialLr The initial learning rate.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return YES if the scheduler was registered successfully, NO otherwise.
|
||||
*/
|
||||
- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
|
||||
totalStepCount:(int64_t)totalStepCount
|
||||
initialLr:(float)initialLr
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Update the learning rate based on the registered learning rate scheduler.
|
||||
*
|
||||
* Performs a scheduler step that updates the learning rate that is being used by the training session.
|
||||
* This function should typically be called before invoking the optimizer step for each round, or as necessary
|
||||
* to update the learning rate being used by the training session.
|
||||
*
|
||||
* @note A valid predefined learning rate scheduler must be first registered to invoke this method.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return YES if the scheduler step was performed successfully, NO otherwise.
|
||||
*/
|
||||
- (BOOL)schedulerStepWithError:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Returns the current learning rate being used by the training session.
|
||||
*
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The current learning rate or 0.0f if an error occurs.
|
||||
*/
|
||||
- (float)getLearningRateWithError:(NSError**)error __attribute__((swift_error(nonnull_error)));
|
||||
|
||||
/**
|
||||
* Sets the learning rate being used by the training session.
|
||||
*
|
||||
* The current learning rate is maintained by the training session and can be overwritten by invoking this method
|
||||
* with the desired learning rate. This function should not be used when a valid learning rate scheduler is registered.
|
||||
* It should be used either to set the learning rate derived from a custom learning rate scheduler or to set a constant
|
||||
* learning rate to be used throughout the training session.
|
||||
*
|
||||
* @note It does not set the initial learning rate that may be needed by the predefined learning rate schedulers.
|
||||
* To set the initial learning rate for learning rate schedulers, use the `registerLinearLRScheduler` method.
|
||||
*
|
||||
* @param lr The learning rate to be used by the training session.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return YES if the learning rate was set successfully, NO otherwise.
|
||||
*/
|
||||
- (BOOL)setLearningRate:(float)lr
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Loads the training session model parameters from a contiguous buffer.
|
||||
*
|
||||
* @param buffer Contiguous buffer to load the parameters from.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return YES if the parameters were loaded successfully, NO otherwise.
|
||||
*/
|
||||
- (BOOL)fromBufferWithValue:(ORTValue*)buffer
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Returns a contiguous buffer that holds a copy of all training state parameters.
|
||||
*
|
||||
* @param onlyTrainable If YES, returns a buffer that holds only the trainable parameters, otherwise returns a buffer
|
||||
* that holds all the parameters.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return A contiguous buffer that holds a copy of all training state parameters.
|
||||
*/
|
||||
- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Exports the training session model that can be used for inference.
|
||||
*
|
||||
* If the training session was provided with an eval model, the training session can generate an inference model if it
|
||||
* knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the
|
||||
* inference model's outputs align with the provided outputs. The exported model is saved at the path provided and
|
||||
* can be used for inferencing with `ORTSession`.
|
||||
*
|
||||
* @note The method reloads the eval model from the path provided to the initializer and expects this path to be valid.
|
||||
*
|
||||
* @param inferenceModelPath The path at which the serialized inference model is saved.
|
||||
* @param graphOutputNames The names of the outputs that are needed in the inference model.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return YES if the inference model was exported successfully, NO otherwise.
|
||||
*/
|
||||
- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath
|
||||
graphOutputNames:(NSArray<NSString*>*)graphOutputNames
|
||||
error:(NSError**)error;
|
||||
@end
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This function sets the seed for generating random numbers.
|
||||
* Use this function to generate reproducible results. It should be noted that completely reproducible results are not guaranteed.
|
||||
*
|
||||
* @param seed Manually set seed to use for random number generation.
|
||||
*/
|
||||
void ORTSetSeed(int64_t seed);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
227
objectivec/ort_training_session.mm
Normal file
227
objectivec/ort_training_session.mm
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import "ort_training_session_internal.h"
|
||||
|
||||
#import <vector>
|
||||
#import <optional>
|
||||
#import <string>
|
||||
|
||||
#import "cxx_api.h"
|
||||
#import "cxx_utils.h"
|
||||
#import "error_utils.h"
|
||||
#import "ort_checkpoint_internal.h"
|
||||
#import "ort_session_internal.h"
|
||||
#import "ort_enums_internal.h"
|
||||
#import "ort_env_internal.h"
|
||||
#import "ort_value_internal.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@implementation ORTTrainingSession {
|
||||
std::optional<Ort::TrainingSession> _session;
|
||||
ORTCheckpoint* _checkpoint;
|
||||
}
|
||||
|
||||
// Internal accessor for the wrapped C++ training session.
// `_session` is dereferenced without checking that it is engaged; it is assigned in
// -initWithEnv:sessionOptions:checkpoint:trainModelPath:evalModelPath:optimizerModelPath:error:.
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
  Ort::TrainingSession& session = *_session;
  return session;
}
|
||||
|
||||
// Designated initializer: builds the underlying Ort::TrainingSession from the env, session options,
// checkpoint and model paths. A nil eval/optimizer path is passed on as an empty optional.
// Any C++ exception is translated by ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE into `error`.
- (nullable instancetype)initWithEnv:(ORTEnv*)env
                      sessionOptions:(ORTSessionOptions*)sessionOptions
                          checkpoint:(ORTCheckpoint*)checkpoint
                      trainModelPath:(NSString*)trainModelPath
                       evalModelPath:(nullable NSString*)evalModelPath
                  optimizerModelPath:(nullable NSString*)optimizerModelPath
                               error:(NSError**)error {
  self = [super init];
  if (self == nil) {
    return nil;
  }

  try {
    const std::optional<std::string> optionalEvalPath = utils::toStdOptionalString(evalModelPath);
    const std::optional<std::string> optionalOptimizerPath = utils::toStdOptionalString(optimizerModelPath);

    // NOTE(review): the checkpoint wrapper is retained in an ivar, presumably so it outlives the
    // C++ session that is constructed from it — confirm.
    _checkpoint = checkpoint;
    _session = Ort::TrainingSession{[env CXXAPIOrtEnv],
                                    [sessionOptions CXXAPIOrtSessionOptions],
                                    [checkpoint CXXAPIOrtCheckpoint],
                                    [trainModelPath UTF8String],
                                    optionalEvalPath,
                                    optionalOptimizerPath};
    return self;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
|
||||
// Runs one training step through the C training API and returns the outputs wrapped as ORTValue.
// The output vector is sized from TrainingSessionGetTrainingModelOutputCount and filled by TrainStep;
// C API failures are raised by Ort::ThrowOnError and translated into `error` by the catch macro.
- (nullable NSArray<ORTValue*>*)trainStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                   error:(NSError**)error {
  try {
    std::vector<const OrtValue*> rawInputs = utils::getWrappedCAPIOrtValues(inputs);

    const auto& trainingApi = Ort::GetTrainingApi();

    size_t numOutputs;
    Ort::ThrowOnError(trainingApi.TrainingSessionGetTrainingModelOutputCount(*_session, &numOutputs));

    // Null slots to be populated by TrainStep.
    std::vector<OrtValue*> rawOutputs(numOutputs, nullptr);

    Ort::RunOptions defaultRunOptions;
    Ort::ThrowOnError(trainingApi.TrainStep(*_session,
                                            defaultRunOptions,
                                            rawInputs.size(),
                                            rawInputs.data(),
                                            rawOutputs.size(),
                                            rawOutputs.data()));

    return utils::wrapUnownedCAPIOrtValues(rawOutputs, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
// Runs one evaluation step through the C training API and returns the outputs wrapped as ORTValue.
// The output vector is sized from TrainingSessionGetEvalModelOutputCount and filled by EvalStep;
// C API failures are raised by Ort::ThrowOnError and translated into `error` by the catch macro.
- (nullable NSArray<ORTValue*>*)evalStepWithInputValues:(NSArray<ORTValue*>*)inputs
                                                  error:(NSError**)error {
  try {
    // Unwrap the ORTValue inputs into the C API pointers expected by EvalStep.
    std::vector<const OrtValue*> rawInputs = utils::getWrappedCAPIOrtValues(inputs);

    const auto& trainingApi = Ort::GetTrainingApi();

    size_t numOutputs;
    Ort::ThrowOnError(trainingApi.TrainingSessionGetEvalModelOutputCount(*_session, &numOutputs));

    // Null slots to be populated by EvalStep.
    std::vector<OrtValue*> rawOutputs(numOutputs, nullptr);

    Ort::RunOptions defaultRunOptions;
    Ort::ThrowOnError(trainingApi.EvalStep(*_session,
                                           defaultRunOptions,
                                           rawInputs.size(),
                                           rawInputs.data(),
                                           rawOutputs.size(),
                                           rawOutputs.data()));

    return utils::wrapUnownedCAPIOrtValues(rawOutputs, error);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
|
||||
// Forwards to Ort::TrainingSession::LazyResetGrad(). Returns YES on success; a C++ exception is
// translated into `error` by ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL.
- (BOOL)lazyResetGradWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.LazyResetGrad();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
|
||||
|
||||
// Forwards to Ort::TrainingSession::OptimizerStep(). Returns YES on success; a C++ exception is
// translated into `error` by ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL.
- (BOOL)optimizerStepWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.OptimizerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
|
||||
|
||||
// Returns InputNames(true) of the wrapped session converted to NSString; the eval counterpart
// passes false. A C++ exception is translated into `error` and nil is returned.
- (nullable NSArray<NSString*>*)getTrainInputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    const std::vector<std::string> names = session.InputNames(true);
    return utils::toNSStringNSArray(names);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
|
||||
// Returns OutputNames(true) of the wrapped session converted to NSString; the eval counterpart
// passes false. A C++ exception is translated into `error` and nil is returned.
- (nullable NSArray<NSString*>*)getTrainOutputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    const std::vector<std::string> names = session.OutputNames(true);
    return utils::toNSStringNSArray(names);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
|
||||
// Returns InputNames(false) of the wrapped session converted to NSString; the training counterpart
// passes true. A C++ exception is translated into `error` and nil is returned.
- (nullable NSArray<NSString*>*)getEvalInputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    const std::vector<std::string> names = session.InputNames(false);
    return utils::toNSStringNSArray(names);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
|
||||
// Returns OutputNames(false) of the wrapped session converted to NSString; the training counterpart
// passes true. A C++ exception is translated into `error` and nil is returned.
- (nullable NSArray<NSString*>*)getEvalOutputNamesWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    const std::vector<std::string> names = session.OutputNames(false);
    return utils::toNSStringNSArray(names);
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
|
||||
// Forwards the three scheduler parameters, in order, to Ort::TrainingSession::RegisterLinearLRScheduler().
// Returns YES on success; a C++ exception is translated into `error` and NO is returned.
- (BOOL)registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.RegisterLinearLRScheduler(warmupStepCount, totalStepCount, initialLr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
|
||||
|
||||
// Forwards to Ort::TrainingSession::SchedulerStep(). Returns YES on success; a C++ exception is
// translated into `error` by ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL.
- (BOOL)schedulerStepWithError:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.SchedulerStep();
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
|
||||
|
||||
// Returns Ort::TrainingSession::GetLearningRate(). On a C++ exception the catch macro is given
// 0.0f as the value to return and `error` to populate.
- (float)getLearningRateWithError:(NSError**)error {
  try {
    const float currentLearningRate = [self CXXAPIOrtTrainingSession].GetLearningRate();
    return currentLearningRate;
  }
  ORT_OBJC_API_IMPL_CATCH(error, 0.0f);
}
|
||||
|
||||
// Forwards `lr` to Ort::TrainingSession::SetLearningRate(). Returns YES on success; a C++ exception
// is translated into `error` and NO is returned.
- (BOOL)setLearningRate:(float)lr
                  error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.SetLearningRate(lr);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
|
||||
|
||||
// Passes the C++ value wrapped by `buffer` to Ort::TrainingSession::FromBuffer().
// Returns YES on success; a C++ exception is translated into `error` and NO is returned.
- (BOOL)fromBufferWithValue:(ORTValue*)buffer
                      error:(NSError**)error {
  try {
    Ort::TrainingSession& session = [self CXXAPIOrtTrainingSession];
    session.FromBuffer([buffer CXXAPIOrtValue]);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
|
||||
|
||||
// Calls Ort::TrainingSession::ToBuffer(onlyTrainable) and hands the resulting value to a new ORTValue.
// A C++ exception is translated into `error` and nil is returned.
- (nullable ORTValue*)toBufferWithTrainable:(BOOL)onlyTrainable
                                      error:(NSError**)error {
  try {
    Ort::Value bufferValue = [self CXXAPIOrtTrainingSession].ToBuffer(onlyTrainable);

    // release() gives up the C++ wrapper's ownership so the raw value can be handed to ORTValue.
    // NOTE(review): if the ORTValue initializer fails, who frees the raw value? — confirm.
    OrtValue* rawBuffer = bufferValue.release();
    ORTValue* wrappedBuffer = [[ORTValue alloc] initWithCAPIOrtValue:rawBuffer
                                                  externalTensorData:nil
                                                               error:error];
    return wrappedBuffer;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
|
||||
|
||||
// Converts the output path and graph output names to their C++ forms and forwards them to
// Ort::TrainingSession::ExportModelForInferencing(). Returns YES on success; a C++ exception is
// translated into `error` and NO is returned.
- (BOOL)exportModelForInferenceWithOutputPath:(NSString*)inferenceModelPath
                             graphOutputNames:(NSArray<NSString*>*)graphOutputNames
                                        error:(NSError**)error {
  try {
    const std::string outputPath = utils::toStdString(inferenceModelPath);
    const std::vector<std::string> outputNames = utils::toStdStringVector(graphOutputNames);

    [self CXXAPIOrtTrainingSession].ExportModelForInferencing(outputPath, outputNames);
    return YES;
  }
  ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
|
||||
|
||||
@end
|
||||
|
||||
// C-linkage entry point declared in ort_training_session.h; forwards `seed` unchanged to Ort::SetSeed().
void ORTSetSeed(int64_t seed) {
|
||||
Ort::SetSeed(seed);
|
||||
}
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
19
objectivec/ort_training_session_internal.h
Normal file
19
objectivec/ort_training_session_internal.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import "ort_training_session.h"
|
||||
|
||||
#import "cxx_api.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface ORTTrainingSession ()
|
||||
|
||||
// Internal (non-public) accessor returning a reference to the underlying C++ Ort::TrainingSession.
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
|
|
@ -29,4 +29,18 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
XCTAssertNotNil(error); \
|
||||
} while (0)
|
||||
|
||||
// Asserts that `result` equals `expected` within an absolute tolerance of 1e-3 and that `error` is nil.
// Note that `expected` and `result` are expanded more than once (in the comparison and in the message).
#define ORTAssertEqualFloatAndNoError(expected, result, error) \
|
||||
do { \
|
||||
XCTAssertEqualWithAccuracy(expected, result, 1e-3f, @"Expected %f but got %f. Error:%@", expected, result, error); \
|
||||
XCTAssertNil(error); \
|
||||
} while (0)
|
||||
|
||||
// Asserts that two arrays of numbers have the same count and that corresponding elements'
// floatValue agree within an absolute tolerance of 1e-3.
// Each argument is evaluated exactly once. Element comparison is bounded by the shorter array so a
// count mismatch is reported as an assertion failure instead of indexing past the end of `result`.
#define ORTAssertEqualFloatArrays(expected, result)                                                        \
  do {                                                                                                     \
    NSArray* ortExpectedArray_ = (expected);                                                               \
    NSArray* ortResultArray_ = (result);                                                                   \
    XCTAssertEqual(ortExpectedArray_.count, ortResultArray_.count);                                        \
    const NSUInteger ortComparableCount_ = MIN(ortExpectedArray_.count, ortResultArray_.count);            \
    for (NSUInteger ortIndex_ = 0; ortIndex_ < ortComparableCount_; ++ortIndex_) {                         \
      XCTAssertEqualWithAccuracy([ortExpectedArray_[ortIndex_] floatValue],                                \
                                 [ortResultArray_[ortIndex_] floatValue], 1e-3f);                          \
    }                                                                                                      \
  } while (0)
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
|
|
|||
|
|
@ -5,7 +5,11 @@
|
|||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "ort_checkpoint.h"
|
||||
#import "ort_training_session.h"
|
||||
#import "ort_env.h"
|
||||
#import "ort_session.h"
|
||||
|
||||
#import "test/test_utils.h"
|
||||
#import "test/assertion_utils.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
|
@ -27,22 +31,41 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
ORTAssertNullableResultSuccessful(_ortEnv, err);
|
||||
}
|
||||
|
||||
- (NSString*)getCheckpointPath {
|
||||
+ (NSString*)getCheckpointPath {
|
||||
NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
|
||||
NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"];
|
||||
return path;
|
||||
}
|
||||
|
||||
- (void)testLoadCheckpoint {
|
||||
+ (NSString*)getTrainingModelPath {
|
||||
NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
|
||||
NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"training_model.onnx"];
|
||||
return path;
|
||||
}
|
||||
|
||||
- (void)testSaveCheckpoint {
|
||||
NSError* error = nil;
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
|
||||
// save checkpoint
|
||||
NSString* path = [test_utils::createTemporaryDirectory(self) stringByAppendingPathComponent:@"save_checkpoint.ckpt"];
|
||||
XCTAssertNotNil(path);
|
||||
BOOL result = [checkpoint saveCheckpointToPath:path withOptimizerState:NO error:&error];
|
||||
|
||||
ORTAssertBoolResultSuccessful(result, error);
|
||||
}
|
||||
|
||||
- (void)testInitCheckpoint {
|
||||
NSError* error = nil;
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
}
|
||||
|
||||
- (void)testIntProperty {
|
||||
NSError* error = nil;
|
||||
// Load checkpoint
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
|
||||
// Add property
|
||||
|
|
@ -57,7 +80,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
- (void)testFloatProperty {
|
||||
NSError* error = nil;
|
||||
// Load checkpoint
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
|
||||
// Add property
|
||||
|
|
@ -72,7 +95,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
- (void)testStringProperty {
|
||||
NSError* error = nil;
|
||||
// Load checkpoint
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
|
||||
// Add property
|
||||
|
|
|
|||
362
objectivec/test/ort_training_session_test.mm
Normal file
362
objectivec/test/ort_training_session_test.mm
Normal file
|
|
@ -0,0 +1,362 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "ort_checkpoint.h"
|
||||
#import "ort_training_session.h"
|
||||
#import "ort_env.h"
|
||||
#import "ort_session.h"
|
||||
#import "ort_value.h"
|
||||
|
||||
#import "test/test_utils.h"
|
||||
#import "test/assertion_utils.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/// Test case for ORTTrainingSession. Each test runs against fixtures created in -setUp.
@interface ORTTrainingSessionTest : XCTestCase

/// ORT environment; assigned in -setUp and cleared in -tearDown.
@property(readonly, nullable) ORTEnv* ortEnv;

/// Checkpoint loaded from the bundled "checkpoint.ckpt"; assigned in -setUp.
@property(readonly, nullable) ORTCheckpoint* checkpoint;

/// Training session created from `checkpoint`; assigned in -setUp.
@property(readonly, nullable) ORTTrainingSession* session;

@end
|
||||
|
||||
@implementation ORTTrainingSessionTest
|
||||
|
||||
// Creates the fixtures shared by all tests: the ORT environment, the checkpoint loaded from
// the bundled "checkpoint.ckpt", and a training session built from that checkpoint.
- (void)setUp {
  [super setUp];

  // Abort a test at its first failed assertion instead of continuing with bad state.
  self.continueAfterFailure = NO;

  NSError* setUpError = nil;

  _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning error:&setUpError];
  ORTAssertNullableResultSuccessful(_ortEnv, setUpError);

  NSString* checkpointPath = [ORTTrainingSessionTest getFilePathFromName:@"checkpoint.ckpt"];
  _checkpoint = [[ORTCheckpoint alloc] initWithPath:checkpointPath error:&setUpError];
  ORTAssertNullableResultSuccessful(_checkpoint, setUpError);

  _session = [self makeTrainingSessionWithCheckpoint:_checkpoint];
}
|
||||
|
||||
// Returns the path of the resource called `name` inside the bundle that contains
// ORTTrainingSessionTest. The path is built by appending `name` to the bundle's resource
// directory; the file's existence is not checked here.
+ (NSString*)getFilePathFromName:(NSString*)name {
  NSString* resourceDirectory =
      [[NSBundle bundleForClass:[ORTTrainingSessionTest class]] resourcePath];
  return [resourceDirectory stringByAppendingPathComponent:name];
}
|
||||
|
||||
// Reads a text fixture and returns the floats found on its first data line as packed
// `float` bytes.
//
// The file is read as UTF-8 and split on newline characters. When `skipHeader` is YES the
// first line is ignored. The first non-empty remaining line is split on ',', '[', ']' and ' ';
// every non-empty token is converted with -floatValue and appended to the result.
//
// Choosing the first *non-empty* line (rather than blindly indexing line 0) keeps the helper
// working for CRLF files, where splitting on the newline character set yields an empty string
// between '\r' and '\n', and avoids an out-of-range exception when the header is the only
// line in the file. A missing data line is reported as a test failure.
+ (NSMutableData*)loadTensorDataFromFile:(NSString*)filePath skipHeader:(BOOL)skipHeader {
  NSError* error = nil;
  NSString* fileContents = [NSString stringWithContentsOfFile:filePath
                                                     encoding:NSUTF8StringEncoding
                                                        error:&error];
  ORTAssertNullableResultSuccessful(fileContents, error);

  NSArray<NSString*>* lines =
      [fileContents componentsSeparatedByCharactersInSet:[NSCharacterSet newlineCharacterSet]];

  // Locate the first line that actually carries data, starting after the optional header.
  NSString* dataLine = nil;
  for (NSUInteger lineIndex = skipHeader ? 1 : 0; lineIndex < lines.count; ++lineIndex) {
    if (lines[lineIndex].length > 0) {
      dataLine = lines[lineIndex];
      break;
    }
  }
  XCTAssertNotNil(dataLine, @"No tensor data line found in file: %@", filePath);

  NSCharacterSet* separators = [NSCharacterSet characterSetWithCharactersInString:@",[] "];
  NSMutableData* tensorData = [NSMutableData data];

  for (NSString* token in [dataLine componentsSeparatedByCharactersInSet:separators]) {
    // Adjacent separators produce empty tokens; skip them.
    if (token.length > 0) {
      float value = [token floatValue];
      [tensorData appendBytes:&value length:sizeof(float)];
    }
  }

  return tensorData;
}
|
||||
|
||||
// Creates a training session for `checkpoint` using default session options and the bundled
// "training_model.onnx", "eval_model.onnx" and "adamw.onnx" models. Creation failures are
// reported through the test assertion macros.
- (ORTTrainingSession*)makeTrainingSessionWithCheckpoint:(ORTCheckpoint*)checkpoint {
  NSError* error = nil;

  ORTSessionOptions* options = [[ORTSessionOptions alloc] initWithError:&error];
  ORTAssertNullableResultSuccessful(options, error);

  NSString* trainModelPath = [ORTTrainingSessionTest getFilePathFromName:@"training_model.onnx"];
  NSString* evalModelPath = [ORTTrainingSessionTest getFilePathFromName:@"eval_model.onnx"];
  NSString* optimizerModelPath = [ORTTrainingSessionTest getFilePathFromName:@"adamw.onnx"];

  ORTTrainingSession* trainingSession = [[ORTTrainingSession alloc] initWithEnv:self.ortEnv
                                                                 sessionOptions:options
                                                                     checkpoint:checkpoint
                                                                 trainModelPath:trainModelPath
                                                                  evalModelPath:evalModelPath
                                                             optimizerModelPath:optimizerModelPath
                                                                          error:&error];
  ORTAssertNullableResultSuccessful(trainingSession, error);

  return trainingSession;
}
|
||||
|
||||
// Checks the training model's input and output names reported by the session created in
// -setUp.
- (void)testInitTrainingSession {
  NSError* error = nil;

  // The training inputs must include "input-0".
  NSArray<NSString*>* trainInputNames = [self.session getTrainInputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(trainInputNames, error);
  XCTAssertTrue(trainInputNames.count > 0);
  XCTAssertTrue([trainInputNames containsObject:@"input-0"]);

  // The training outputs must include "onnx::loss::21273".
  NSArray<NSString*>* trainOutputNames = [self.session getTrainOutputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(trainOutputNames, error);
  XCTAssertTrue(trainOutputNames.count > 0);
  XCTAssertTrue([trainOutputNames containsObject:@"onnx::loss::21273"]);
}
|
||||
|
||||
// Checks the eval model's input and output names reported by the session created in -setUp.
- (void)testInitTrainingSessionWithEval {
  NSError* error = nil;

  // The eval inputs must include "input-0".
  NSArray<NSString*>* evalInputNames = [self.session getEvalInputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(evalInputNames, error);
  XCTAssertTrue(evalInputNames.count > 0);
  XCTAssertTrue([evalInputNames containsObject:@"input-0"]);

  // The eval outputs must include "onnx::loss::21273".
  NSArray<NSString*>* evalOutputNames = [self.session getEvalOutputNamesWithError:&error];
  ORTAssertNullableResultSuccessful(evalOutputNames, error);
  XCTAssertTrue(evalOutputNames.count > 0);
  XCTAssertTrue([evalOutputNames containsObject:@"onnx::loss::21273"]);
}
|
||||
|
||||
// Shared helper: runs a train step, lazily resets the gradients, runs a second train step and
// checks that the first output of the second step is a float tensor equal to the values stored
// in the "loss_1.out" fixture.
- (void)runTrainStep {
  NSError* error = nil;

  // Fixtures: expected loss and the float input batch.
  NSString* expectedLossPath = [ORTTrainingSessionTest getFilePathFromName:@"loss_1.out"];
  NSMutableData* expectedLoss = [ORTTrainingSessionTest loadTensorDataFromFile:expectedLossPath
                                                                    skipHeader:YES];

  NSString* inputPath = [ORTTrainingSessionTest getFilePathFromName:@"input-0.in"];
  NSMutableData* inputData = [ORTTrainingSessionTest loadTensorDataFromFile:inputPath
                                                                 skipHeader:YES];

  int32_t labels[] = {1, 1};

  // Wrap the input batch and the labels in ORTValues.
  ORTValue* inputTensor = [[ORTValue alloc] initWithTensorData:inputData
                                                   elementType:ORTTensorElementDataTypeFloat
                                                         shape:@[ @2, @784 ]
                                                         error:&error];
  ORTAssertNullableResultSuccessful(inputTensor, error);

  NSMutableData* labelData = [NSMutableData dataWithBytes:labels length:sizeof(labels)];
  ORTValue* labelTensor = [[ORTValue alloc] initWithTensorData:labelData
                                                   elementType:ORTTensorElementDataTypeInt32
                                                         shape:@[ @2 ]
                                                         error:&error];
  ORTAssertNullableResultSuccessful(labelTensor, error);

  NSArray<ORTValue*>* trainInputs = @[ inputTensor, labelTensor ];

  // First train step.
  NSArray<ORTValue*>* stepOutputs = [self.session trainStepWithInputValues:trainInputs
                                                                      error:&error];
  ORTAssertNullableResultSuccessful(stepOutputs, error);
  XCTAssertTrue(stepOutputs.count > 0);

  BOOL resetSucceeded = [self.session lazyResetGradWithError:&error];
  ORTAssertBoolResultSuccessful(resetSucceeded, error);

  // Second train step; its first output is the value that gets verified.
  stepOutputs = [self.session trainStepWithInputValues:trainInputs error:&error];
  ORTAssertNullableResultSuccessful(stepOutputs, error);
  XCTAssertTrue(stepOutputs.count > 0);

  ORTValue* lossValue = stepOutputs[0];

  ORTValueTypeInfo* lossTypeInfo = [lossValue typeInfoWithError:&error];
  ORTAssertNullableResultSuccessful(lossTypeInfo, error);
  XCTAssertEqual(lossTypeInfo.type, ORTValueTypeTensor);
  XCTAssertNotNil(lossTypeInfo.tensorTypeAndShapeInfo);

  ORTTensorTypeAndShapeInfo* lossTensorInfo = [lossValue tensorTypeAndShapeInfoWithError:&error];
  ORTAssertNullableResultSuccessful(lossTensorInfo, error);
  XCTAssertEqual(lossTensorInfo.elementType, ORTTensorElementDataTypeFloat);

  NSMutableData* lossData = [lossValue tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(lossData, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(lossData),
                            test_utils::getFloatArrayFromData(expectedLoss));
}
|
||||
|
||||
// Runs the shared train-step helper, which performs all of the assertions for this test.
- (void)testTrainStepOutput {
  [self runTrainStep];
}
|
||||
|
||||
// Verifies the loss across train and optimizer steps:
//   1. the first train step yields the loss stored in "loss_1.out";
//   2. after a lazy gradient reset, another train step still yields "loss_1.out";
//   3. after an optimizer step, the next train step yields the loss in "loss_2.out".
//
// Every result array is checked for being non-empty before its first element is read, so an
// unexpectedly empty result is reported as a test failure rather than a range exception.
- (void)testOptimizerStep {
  NSError* error = nil;

  // Load the expected losses and the input batch.
  NSMutableData* expectedOutput1 = [ORTTrainingSessionTest
      loadTensorDataFromFile:[ORTTrainingSessionTest getFilePathFromName:@"loss_1.out"]
                  skipHeader:YES];
  NSMutableData* expectedOutput2 = [ORTTrainingSessionTest
      loadTensorDataFromFile:[ORTTrainingSessionTest getFilePathFromName:@"loss_2.out"]
                  skipHeader:YES];
  NSMutableData* input = [ORTTrainingSessionTest
      loadTensorDataFromFile:[ORTTrainingSessionTest getFilePathFromName:@"input-0.in"]
                  skipHeader:YES];

  int32_t labels[] = {1, 1};

  // Create the ORTValue array holding the input batch and the labels.
  NSMutableArray<ORTValue*>* inputValues = [NSMutableArray array];

  ORTValue* inputTensor = [[ORTValue alloc] initWithTensorData:input
                                                   elementType:ORTTensorElementDataTypeFloat
                                                         shape:@[ @2, @784 ]
                                                         error:&error];
  ORTAssertNullableResultSuccessful(inputTensor, error);
  [inputValues addObject:inputTensor];

  ORTValue* labelTensor =
      [[ORTValue alloc] initWithTensorData:[NSMutableData dataWithBytes:labels
                                                                 length:sizeof(labels)]
                               elementType:ORTTensorElementDataTypeInt32
                                     shape:@[ @2 ]
                                     error:&error];
  ORTAssertNullableResultSuccessful(labelTensor, error);
  [inputValues addObject:labelTensor];

  // 1. First train step.
  NSArray<ORTValue*>* outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);

  NSMutableData* loss = [outputs[0] tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(loss, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss),
                            test_utils::getFloatArrayFromData(expectedOutput1));

  // 2. Reset the gradients and train again; the expected loss is still "loss_1.out".
  BOOL result = [self.session lazyResetGradWithError:&error];
  ORTAssertBoolResultSuccessful(result, error);

  outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);

  loss = [outputs[0] tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(loss, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss),
                            test_utils::getFloatArrayFromData(expectedOutput1));

  // 3. Run the optimizer, then train once more; the expected loss is now "loss_2.out".
  result = [self.session optimizerStepWithError:&error];
  ORTAssertBoolResultSuccessful(result, error);

  outputs = [self.session trainStepWithInputValues:inputValues error:&error];
  ORTAssertNullableResultSuccessful(outputs, error);
  XCTAssertTrue(outputs.count > 0);

  loss = [outputs[0] tensorDataWithError:&error];
  ORTAssertNullableResultSuccessful(loss, error);
  ORTAssertEqualFloatArrays(test_utils::getFloatArrayFromData(loss),
                            test_utils::getFloatArrayFromData(expectedOutput2));
}
|
||||
|
||||
// Sets a learning rate on the session and checks that reading it back returns the same value.
- (void)testSetLearningRate {
  NSError* error = nil;
  const float expectedLearningRate = 0.1f;

  BOOL setSucceeded = [self.session setLearningRate:expectedLearningRate error:&error];
  ORTAssertBoolResultSuccessful(setSucceeded, error);

  float reportedLearningRate = [self.session getLearningRateWithError:&error];
  ORTAssertEqualFloatAndNoError(expectedLearningRate, reportedLearningRate, error);
}
|
||||
|
||||
// Registers a linear learning-rate scheduler (2 warmup steps, 4 total steps, initial rate 0.1)
// and checks the learning rate reported after each of four optimizer + scheduler steps:
// 0.05, 0.1, 0.05 and finally 0.0.
- (void)testLinearLRScheduler {
  NSError* error = nil;

  const float initialLearningRate = 0.1f;
  BOOL result = [self.session registerLinearLRSchedulerWithWarmupStepCount:2
                                                            totalStepCount:4
                                                                 initialLr:initialLearningRate
                                                                     error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  [self runTrainStep];

  // Learning rate expected after each successive optimizer + scheduler step.
  const float expectedLearningRates[] = {0.05f, 0.1f, 0.05f, 0.0f};
  const size_t stepCount = sizeof(expectedLearningRates) / sizeof(expectedLearningRates[0]);

  for (size_t step = 0; step < stepCount; ++step) {
    result = [self.session optimizerStepWithError:&error];
    ORTAssertBoolResultSuccessful(result, error);

    result = [self.session schedulerStepWithError:&error];
    ORTAssertBoolResultSuccessful(result, error);

    ORTAssertEqualFloatAndNoError(expectedLearningRates[step],
                                  [self.session getLearningRateWithError:&error], error);
  }
}
|
||||
|
||||
// Exports an inference model with the single graph output "output-0" into the temporary test
// directory and checks that the file exists afterwards.
- (void)testExportModelForInference {
  NSError* error = nil;

  NSString* inferenceModelPath = [test_utils::createTemporaryDirectory(self)
      stringByAppendingPathComponent:@"inference_model.onnx"];
  XCTAssertNotNil(inferenceModelPath);

  NSArray<NSString*>* graphOutputNames = @[ @"output-0" ];

  BOOL result = [self.session exportModelForInferenceWithOutputPath:inferenceModelPath
                                                   graphOutputNames:graphOutputNames
                                                              error:&error];
  ORTAssertBoolResultSuccessful(result, error);
  XCTAssertTrue([[NSFileManager defaultManager] fileExistsAtPath:inferenceModelPath]);

  // Best-effort cleanup of the exported file; the outcome of the removal is deliberately not
  // checked.
  [self addTeardownBlock:^{
    [[NSFileManager defaultManager] removeItemAtPath:inferenceModelPath error:nil];
  }];
}
|
||||
|
||||
// Copies the session's trainable parameters into a buffer value and checks that the result is
// a tensor with tensor type/shape information available.
- (void)testToBuffer {
  NSError* error = nil;

  ORTValue* parameterBuffer = [self.session toBufferWithTrainable:YES error:&error];
  ORTAssertNullableResultSuccessful(parameterBuffer, error);

  ORTValueTypeInfo* bufferTypeInfo = [parameterBuffer typeInfoWithError:&error];
  ORTAssertNullableResultSuccessful(bufferTypeInfo, error);
  XCTAssertEqual(bufferTypeInfo.type, ORTValueTypeTensor);
  XCTAssertNotNil(bufferTypeInfo.tensorTypeAndShapeInfo);
}
|
||||
|
||||
// Round trip: obtains a buffer of the trainable parameters from the session and loads that
// same buffer back into the session.
- (void)testFromBuffer {
  NSError* error = nil;

  ORTValue* parameterBuffer = [self.session toBufferWithTrainable:YES error:&error];
  ORTAssertNullableResultSuccessful(parameterBuffer, error);

  BOOL loadSucceeded = [self.session fromBufferWithValue:parameterBuffer error:&error];
  ORTAssertBoolResultSuccessful(loadSucceeded, error);
}
|
||||
|
||||
// Drops the fixtures created in -setUp. The references are cleared in the reverse of their
// creation order: session, then checkpoint, then environment.
- (void)tearDown {
  _session = nil;
  _checkpoint = nil;
  _ortEnv = nil;

  [super tearDown];
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
29
objectivec/test/ort_training_utils_test.mm
Normal file
29
objectivec/test/ort_training_utils_test.mm
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import <XCTest/XCTest.h>
|
||||
#import "ort_training_session.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/// Smoke test for the training utility functions declared in ort_training_session.h.
@interface ORTTrainingUtilsTest : XCTestCase
@end

@implementation ORTTrainingUtilsTest

- (void)setUp {
  [super setUp];

  // Abort a test at its first failed assertion.
  self.continueAfterFailure = NO;
}

// Only verifies that ORTSetSeed can be called; nothing is asserted about its effect.
- (void)testSetSeed {
  ORTSetSeed(2718);
}

@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
21
objectivec/test/test_utils.h
Normal file
21
objectivec/test/test_utils.h
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "ort_session.h"
|
||||
#import "ort_env.h"
|
||||
#import "ort_value.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
namespace test_utils {

/// Creates the directory "ort-objective-c-test" inside NSTemporaryDirectory() (including any
/// missing intermediate directories) and registers a teardown block on `testCase` that removes
/// the directory again.
/// @param testCase The test case on which the cleanup teardown block is registered.
/// @return The path of the temporary directory.
NSString* _Nullable createTemporaryDirectory(XCTestCase* testCase);

/// Converts raw bytes into an array of boxed floats.
/// @param data Bytes interpreted as consecutive `float` values; trailing bytes that do not
///             fill a whole `float` are ignored.
/// @return One NSNumber per `float` found in `data`, in order.
NSArray<NSNumber*>* getFloatArrayFromData(NSData* data);

}  // namespace test_utils
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
44
objectivec/test/test_utils.mm
Normal file
44
objectivec/test/test_utils.mm
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#import "test_utils.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
namespace test_utils {
|
||||
|
||||
// Creates the directory "ort-objective-c-test" inside NSTemporaryDirectory() and registers a
// teardown block on `testCase` that deletes it when the test finishes.
//
// Success is judged by the BOOL results of the NSFileManager calls, not by inspecting the
// NSError out-parameter. If the directory cannot be created, the failure is recorded and nil is
// returned, matching the `_Nullable` return type declared in test_utils.h.
//
// @param testCase The test case on which the cleanup teardown block is registered.
// @return The directory path, or nil if the directory could not be created.
NSString* _Nullable createTemporaryDirectory(XCTestCase* testCase) {
  NSString* directoryPath =
      [NSTemporaryDirectory() stringByAppendingPathComponent:@"ort-objective-c-test"];

  NSError* error = nil;
  BOOL created = [[NSFileManager defaultManager] createDirectoryAtPath:directoryPath
                                           withIntermediateDirectories:YES
                                                            attributes:nil
                                                                 error:&error];
  XCTAssertTrue(created, @"Error creating temporary directory: %@", error.localizedDescription);
  if (!created) {
    return nil;
  }

  // Delete the temporary directory once the test is done.
  [testCase addTeardownBlock:^{
    NSFileManager* fileManager = [NSFileManager defaultManager];

    // The directory may already be gone, e.g. when this helper was called more than once in
    // the same test and an earlier teardown block removed it. That is not an error.
    if (![fileManager fileExistsAtPath:directoryPath]) {
      return;
    }

    NSError* removeError = nil;
    BOOL removed = [fileManager removeItemAtPath:directoryPath error:&removeError];
    XCTAssertTrue(removed, @"Error removing temporary directory: %@",
                  removeError.localizedDescription);
  }];

  return directoryPath;
}
|
||||
|
||||
// Interprets `data` as consecutive `float` values and returns them boxed as NSNumbers, in
// order. Trailing bytes that do not make up a whole `float` are ignored.
NSArray<NSNumber*>* getFloatArrayFromData(NSData* data) {
  const NSUInteger floatCount = data.length / sizeof(float);
  NSMutableArray<NSNumber*>* floats = [NSMutableArray arrayWithCapacity:floatCount];

  for (NSUInteger index = 0; index < floatCount; ++index) {
    float element;
    // Copy the bytes out instead of casting the buffer.
    [data getBytes:&element range:NSMakeRange(index * sizeof(float), sizeof(float))];
    [floats addObject:@(element)];
  }

  return floats;
}
|
||||
|
||||
} // namespace test_utils
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
|
@ -194,7 +194,6 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
|
|||
* \param[in] input_values The user inputs to the training model.
|
||||
* \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
std::vector<Value> TrainStep(const std::vector<Value>& input_values);
|
||||
|
|
|
|||
Loading…
Reference in a new issue