diff --git a/cmake/onnxruntime_objectivec.cmake b/cmake/onnxruntime_objectivec.cmake index 1051607731..494f28f403 100644 --- a/cmake/onnxruntime_objectivec.cmake +++ b/cmake/onnxruntime_objectivec.cmake @@ -40,6 +40,16 @@ file(GLOB onnxruntime_objc_srcs CONFIGURE_DEPENDS "${OBJC_ROOT}/*.m" "${OBJC_ROOT}/*.mm") +if(NOT onnxruntime_ENABLE_TRAINING_APIS) + list(REMOVE_ITEM onnxruntime_objc_headers + "${OBJC_ROOT}/include/ort_checkpoint.h") + + list(REMOVE_ITEM onnxruntime_objc_srcs + "${OBJC_ROOT}/ort_checkpoint_internal.h" + "${OBJC_ROOT}/ort_checkpoint.mm") +endif() + + source_group(TREE "${OBJC_ROOT}" FILES ${onnxruntime_objc_headers} ${onnxruntime_objc_srcs}) @@ -61,6 +71,13 @@ if(onnxruntime_USE_COREML) "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml") endif() +if (onnxruntime_ENABLE_TRAINING_APIS) + target_include_directories(onnxruntime_objc + PRIVATE + "${ORTTRAINING_SOURCE_DIR}/training_api/include/") + +endif() + find_library(FOUNDATION_LIB Foundation REQUIRED) target_link_libraries(onnxruntime_objc @@ -105,6 +122,12 @@ if(onnxruntime_BUILD_UNIT_TESTS) "${OBJC_ROOT}/test/*.m" "${OBJC_ROOT}/test/*.mm") + if(NOT onnxruntime_ENABLE_TRAINING_APIS) + list(REMOVE_ITEM onnxruntime_objc_test_srcs + "${OBJC_ROOT}/test/ort_checkpoint_test.mm") + + endif() + source_group(TREE "${OBJC_ROOT}" FILES ${onnxruntime_objc_test_srcs}) xctest_add_bundle(onnxruntime_objc_test onnxruntime_objc @@ -124,6 +147,7 @@ if(onnxruntime_BUILD_UNIT_TESTS) add_custom_command(TARGET onnxruntime_objc_test POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory "${OBJC_ROOT}/test/testdata" + "${ONNXRUNTIME_ROOT}/test/testdata/training_api" "$/Resources") xctest_add_test(XCTest.onnxruntime_objc_test onnxruntime_objc_test) diff --git a/objectivec/cxx_api.h b/objectivec/cxx_api.h index 3e4821c24a..26acfb8b86 100644 --- a/objectivec/cxx_api.h +++ b/objectivec/cxx_api.h @@ -11,29 +11,28 @@ #endif // defined(__clang__) // paths are different when building the Swift Package Manager package as the headers come from the iOS pod archive +// clang-format off +#define STRINGIFY(x) #x #ifdef SPM_BUILD -#include "onnxruntime/onnxruntime_c_api.h" -#include "onnxruntime/onnxruntime_cxx_api.h" - -#if __has_include("onnxruntime/coreml_provider_factory.h") -#define ORT_OBJC_API_COREML_EP_AVAILABLE 1 -#include "onnxruntime/coreml_provider_factory.h" +#define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(onnxruntime/x) #else -#define ORT_OBJC_API_COREML_EP_AVAILABLE 0 +#define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(x) +#endif +// clang-format on + +#ifndef ENABLE_TRAINING_APIS +#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_c_api.h) +#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_cxx_api.h) +#else +#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_c_api.h) +#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_cxx_api.h) #endif -#else -#include "onnxruntime_c_api.h" -#include "onnxruntime_cxx_api.h" - -#if __has_include("coreml_provider_factory.h") +#if __has_include(ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h)) #define ORT_OBJC_API_COREML_EP_AVAILABLE 1 -#include "coreml_provider_factory.h" +#include ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h) #else #define ORT_OBJC_API_COREML_EP_AVAILABLE 0 - -#endif - #endif #if defined(__clang__) diff --git a/objectivec/error_utils.h b/objectivec/error_utils.h index 274e74aec1..4df71ce38d 100644 --- a/objectivec/error_utils.h +++ b/objectivec/error_utils.h @@ -10,6 +10,7 @@ NS_ASSUME_NONNULL_BEGIN void ORTSaveCodeAndDescriptionToError(int code, const char* description, NSError** error); +void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error); void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error); void ORTSaveExceptionToError(const std::exception& e, NSError** error); diff --git a/objectivec/error_utils.mm b/objectivec/error_utils.mm index 91863262ca..335cf8894d 100644 --- a/objectivec/error_utils.mm +++ b/objectivec/error_utils.mm @@ -18,6 +18,14 @@ void ORTSaveCodeAndDescriptionToError(int code, const char* descriptionCstr, NSE userInfo:@{NSLocalizedDescriptionKey : description}]; } +void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error) { + if (!error) return; + + *error = [NSError errorWithDomain:kOrtErrorDomain + code:code + userInfo:@{NSLocalizedDescriptionKey : description}]; +} + void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error) { ORTSaveCodeAndDescriptionToError(e.GetOrtErrorCode(), e.what(), error); } diff --git a/objectivec/include/ort_checkpoint.h b/objectivec/include/ort_checkpoint.h new file mode 100644 index 0000000000..8c55f21943 --- /dev/null +++ b/objectivec/include/ort_checkpoint.h @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import +#include + +NS_ASSUME_NONNULL_BEGIN + +/** + * An ORT checkpoint is a snapshot of the state of a model at a given point in time. + * + * This class holds the entire training session state that includes model parameters, + * their gradients, optimizer parameters, and user properties. The ORTTrainingSession leverages the + * ORTCheckpointState by accessing and updating the contained training state. + * + * Available since v1.16.0. + * + * @note Note that the training session created with a checkpoint state uses this state to store the entire training + * state (including model parameters, its gradients, the optimizer states and the properties). The ORTTraingSession + * does not hold a copy of the checkpoint state. Therefore, it is required that the checkpoint state outlive the + * lifetime of the training session. + */ +@interface ORTCheckpoint : NSObject + +- (instancetype)init NS_UNAVAILABLE; + +/** + * Creates a checkpoint from directory on disk. + * + * @param path The path to the checkpoint directory. + * @param error Optional error information set if an error occurs. + * @return The instance, or nil if an error occurs. + * + * @warning The construction of the checkpoint state requires instantiation of `ORTEnv`. + * The intialization will fail if the `ORTEnv` is not properly initialized. + */ +- (nullable instancetype)initWithPath:(NSString*)path + error:(NSError**)error NS_DESIGNATED_INITIALIZER; + +/** + * Saves a checkpoint to directory on disk. + * + * @param path The path to the checkpoint directory. + * @param includeOptimizerState Flag to indicate whether to save the optimizer state or not. + * @param error Optional error information set if an error occurs. + * @return Whether the checkpoint was saved successfully. + */ +- (BOOL)saveCheckpointToPath:(NSString*)path + withOptimizerState:(BOOL)includeOptimizerState + error:(NSError**)error; + +/** + * Adds an int property to this checkpoint. + * + * @param name The name of the property. + * @param value The value of the property. + * @param error Optional error information set if an error occurs. + * @return Whether the property was added successfully. + */ +- (BOOL)addIntPropertyWithName:(NSString*)name + value:(int64_t)value + error:(NSError**)error; + +/** + * Adds a float property to this checkpoint. + * + * @param name The name of the property. + * @param value The value of the property. + * @param error Optional error information set if an error occurs. + * @return Whether the property was added successfully. + */ +- (BOOL)addFloatPropertyWithName:(NSString*)name + value:(float)value + error:(NSError**)error; + +/** + * Adds a string property to this checkpoint. + * + * @param name The name of the property. + * @param value The value of the property. + * @param error Optional error information set if an error occurs. + * @return Whether the property was added successfully. + */ + +- (BOOL)addStringPropertyWithName:(NSString*)name + value:(NSString*)value + error:(NSError**)error; + +/** + * Gets an int property from this checkpoint. + * + * @param name The name of the property. + * @param error Optional error information set if an error occurs. + * @return The value of the property or 0 if an error occurs. + */ +- (int64_t)getIntPropertyWithName:(NSString*)name + error:(NSError**)error __attribute__((swift_error(nonnull_error))); + +/** + * Gets a float property from this checkpoint. + * + * @param name The name of the property. + * @param error Optional error information set if an error occurs. + * @return The value of the property or 0.0f if an error occurs. + */ +- (float)getFloatPropertyWithName:(NSString*)name + error:(NSError**)error __attribute__((swift_error(nonnull_error))); + +/** + * + * Gets a string property from this checkpoint. + * + * @param name The name of the property. + * @param error Optional error information set if an error occurs. + * @return The value of the property. + */ +- (nullable NSString*)getStringPropertyWithName:(NSString*)name + error:(NSError**)error __attribute__((swift_error(nonnull_error))); + +@end + +NS_ASSUME_NONNULL_END + +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_checkpoint.mm b/objectivec/ort_checkpoint.mm new file mode 100644 index 0000000000..ee88e9c9c1 --- /dev/null +++ b/objectivec/ort_checkpoint.mm @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import "ort_checkpoint_internal.h" + +#include +#include +#include +#import "cxx_api.h" + +#import "error_utils.h" + +NS_ASSUME_NONNULL_BEGIN + +@implementation ORTCheckpoint { + std::optional _checkpoint; +} + +- (nullable instancetype)initWithPath:(NSString*)path + error:(NSError**)error { + if ((self = [super init]) == nil) { + return nil; + } + + try { + _checkpoint = Ort::CheckpointState::LoadCheckpoint(path.UTF8String); + return self; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (BOOL)saveCheckpointToPath:(NSString*)path + withOptimizerState:(BOOL)includeOptimizerState + error:(NSError**)error { + try { + Ort::CheckpointState::SaveCheckpoint([self CXXAPIOrtCheckpoint], path.UTF8String, includeOptimizerState); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (BOOL)addIntPropertyWithName:(NSString*)name + value:(int64_t)value + error:(NSError**)error { + try { + [self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (BOOL)addFloatPropertyWithName:(NSString*)name + value:(float)value + error:(NSError**)error { + try { + [self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (BOOL)addStringPropertyWithName:(NSString*)name + value:(NSString*)value + error:(NSError**)error { + try { + [self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value.UTF8String); + return YES; + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (nullable NSString*)getStringPropertyWithName:(NSString*)name error:(NSError**)error { + try { + Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String); + if (std::string* str = std::get_if(&value)) { + return [NSString stringWithUTF8String:str->c_str()]; + } + ORT_CXX_API_THROW("Property is not a string.", ORT_INVALID_ARGUMENT); + } + ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (int64_t)getIntPropertyWithName:(NSString*)name error:(NSError**)error { + try { + Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String); + if (int64_t* i = std::get_if(&value)) { + return *i; + } + ORT_CXX_API_THROW("Property is not an integer.", ORT_INVALID_ARGUMENT); + } + ORT_OBJC_API_IMPL_CATCH(error, 0) +} + +- (float)getFloatPropertyWithName:(NSString*)name error:(NSError**)error { + try { + Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String); + if (float* f = std::get_if(&value)) { + return *f; + } + ORT_CXX_API_THROW("Property is not a float.", ORT_INVALID_ARGUMENT); + } + ORT_OBJC_API_IMPL_CATCH(error, 0.0f) +} + +- (Ort::CheckpointState&)CXXAPIOrtCheckpoint { + return *_checkpoint; +} + +@end + +NS_ASSUME_NONNULL_END +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_checkpoint_internal.h b/objectivec/ort_checkpoint_internal.h new file mode 100644 index 0000000000..0001913b6e --- /dev/null +++ b/objectivec/ort_checkpoint_internal.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import "ort_checkpoint.h" + +#import "cxx_api.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface ORTCheckpoint () + +- (Ort::CheckpointState&)CXXAPIOrtCheckpoint; + +@end + +NS_ASSUME_NONNULL_END +#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/test/ort_checkpoint_test.mm b/objectivec/test/ort_checkpoint_test.mm new file mode 100644 index 0000000000..01788a5bc0 --- /dev/null +++ b/objectivec/test/ort_checkpoint_test.mm @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_APIS +#import + +#import "ort_checkpoint.h" +#import "ort_env.h" +#import "test/assertion_utils.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface ORTCheckpointTest : XCTestCase +@property(readonly, nullable) ORTEnv* ortEnv; +@end + +@implementation ORTCheckpointTest + +- (void)setUp { + [super setUp]; + + self.continueAfterFailure = NO; + + NSError* err = nil; + _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning + error:&err]; + ORTAssertNullableResultSuccessful(_ortEnv, err); +} + +- (NSString*)getCheckpointPath { + NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]]; + NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"]; + return path; +} + +- (void)testLoadCheckpoint { + NSError* error = nil; + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTAssertNullableResultSuccessful(checkpoint, error); +} + +- (void)testIntProperty { + NSError* error = nil; + // Load checkpoint + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTAssertNullableResultSuccessful(checkpoint, error); + + // Add property + BOOL result = [checkpoint addIntPropertyWithName:@"test" value:314 error:&error]; + ORTAssertBoolResultSuccessful(result, error); + + // Get property + int64_t value = [checkpoint getIntPropertyWithName:@"test" error:&error]; + XCTAssertEqual(value, 314); +} + +- (void)testFloatProperty { + NSError* error = nil; + // Load checkpoint + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTAssertNullableResultSuccessful(checkpoint, error); + + // Add property + BOOL result = [checkpoint addFloatPropertyWithName:@"test" value:3.14f error:&error]; + ORTAssertBoolResultSuccessful(result, error); + + // Get property + float value = [checkpoint getFloatPropertyWithName:@"test" error:&error]; + XCTAssertEqual(value, 3.14f); +} + +- (void)testStringProperty { + NSError* error = nil; + // Load checkpoint + ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error]; + ORTAssertNullableResultSuccessful(checkpoint, error); + + // Add property + BOOL result = [checkpoint addStringPropertyWithName:@"test" value:@"hello" error:&error]; + ORTAssertBoolResultSuccessful(result, error); + + // Get property + NSString* value = [checkpoint getStringPropertyWithName:@"test" error:&error]; + XCTAssertEqualObjects(value, @"hello"); +} + +- (void)tearDown { + _ortEnv = nil; + + [super tearDown]; +} + +@end + +NS_ASSUME_NONNULL_END + +#endif // ENABLE_TRAINING_APIS diff --git a/tools/ci_build/github/azure-pipelines/orttraining-mac-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-mac-ci-pipeline.yml index 57e53ad245..f1843b590e 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-mac-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-mac-ci-pipeline.yml @@ -3,5 +3,5 @@ stages: parameters: AllowReleasedOpsetOnly: 0 BuildForAllArchs: false - AdditionalBuildFlags: --enable_training + AdditionalBuildFlags: --enable_training --build_objc WithCache: true