Objective-C binding for ORT training (#16127)

### Description
Implement Objective-C binding for `ORTCheckPoint`. Additionally, 
- Modify `onnxruntime_objectivec.cmake` to only include training header
and sources when training flag is enabled
- Enable objective-c binding for `orttraining-mac-ci-pipeline`

### Motivation and Context
This PR is part of implementing Objective-C bindings for training API.
It implements objective-c binding for ORTCheckPoint class. The
objective-C API closely resembles the C++ API.

**Note**: The test for saving checkpoint is skipped as it requires use
of training session. It will be added when the objective-c binding for
`ORTTrainingSession` is added.
This commit is contained in:
Vrajang Parikh 2023-06-07 14:01:30 -07:00 committed by GitHub
parent bca49d62a0
commit 67f4a4fd16
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 402 additions and 17 deletions

View file

@ -40,6 +40,16 @@ file(GLOB onnxruntime_objc_srcs CONFIGURE_DEPENDS
"${OBJC_ROOT}/*.m"
"${OBJC_ROOT}/*.mm")
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
list(REMOVE_ITEM onnxruntime_objc_headers
"${OBJC_ROOT}/include/ort_checkpoint.h")
list(REMOVE_ITEM onnxruntime_objc_srcs
"${OBJC_ROOT}/ort_checkpoint_internal.h"
"${OBJC_ROOT}/ort_checkpoint.mm")
endif()
source_group(TREE "${OBJC_ROOT}" FILES
${onnxruntime_objc_headers}
${onnxruntime_objc_srcs})
@ -61,6 +71,13 @@ if(onnxruntime_USE_COREML)
"${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml")
endif()
if (onnxruntime_ENABLE_TRAINING_APIS)
target_include_directories(onnxruntime_objc
PRIVATE
"${ORTTRAINING_SOURCE_DIR}/training_api/include/")
endif()
find_library(FOUNDATION_LIB Foundation REQUIRED)
target_link_libraries(onnxruntime_objc
@ -105,6 +122,12 @@ if(onnxruntime_BUILD_UNIT_TESTS)
"${OBJC_ROOT}/test/*.m"
"${OBJC_ROOT}/test/*.mm")
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
list(REMOVE_ITEM onnxruntime_objc_test_srcs
"${OBJC_ROOT}/test/ort_checkpoint_test.mm")
endif()
source_group(TREE "${OBJC_ROOT}" FILES ${onnxruntime_objc_test_srcs})
xctest_add_bundle(onnxruntime_objc_test onnxruntime_objc
@ -124,6 +147,7 @@ if(onnxruntime_BUILD_UNIT_TESTS)
add_custom_command(TARGET onnxruntime_objc_test POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
"${OBJC_ROOT}/test/testdata"
"${ONNXRUNTIME_ROOT}/test/testdata/training_api"
"$<TARGET_BUNDLE_CONTENT_DIR:onnxruntime_objc_test>/Resources")
xctest_add_test(XCTest.onnxruntime_objc_test onnxruntime_objc_test)

View file

@ -11,29 +11,28 @@
#endif // defined(__clang__)
// paths are different when building the Swift Package Manager package as the headers come from the iOS pod archive
// clang-format off
#define STRINGIFY(x) #x
#ifdef SPM_BUILD
#include "onnxruntime/onnxruntime_c_api.h"
#include "onnxruntime/onnxruntime_cxx_api.h"
#if __has_include("onnxruntime/coreml_provider_factory.h")
#define ORT_OBJC_API_COREML_EP_AVAILABLE 1
#include "onnxruntime/coreml_provider_factory.h"
#define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(onnxruntime/x)
#else
#define ORT_OBJC_API_COREML_EP_AVAILABLE 0
#define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(x)
#endif
// clang-format on
#ifndef ENABLE_TRAINING_APIS
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_c_api.h)
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_cxx_api.h)
#else
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_c_api.h)
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_cxx_api.h)
#endif
#else
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#if __has_include("coreml_provider_factory.h")
#if __has_include(ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h))
#define ORT_OBJC_API_COREML_EP_AVAILABLE 1
#include "coreml_provider_factory.h"
#include ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h)
#else
#define ORT_OBJC_API_COREML_EP_AVAILABLE 0
#endif
#endif
#if defined(__clang__)

View file

@ -10,6 +10,7 @@
NS_ASSUME_NONNULL_BEGIN
void ORTSaveCodeAndDescriptionToError(int code, const char* description, NSError** error);
void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error);
void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error);
void ORTSaveExceptionToError(const std::exception& e, NSError** error);

View file

@ -18,6 +18,14 @@ void ORTSaveCodeAndDescriptionToError(int code, const char* descriptionCstr, NSE
userInfo:@{NSLocalizedDescriptionKey : description}];
}
void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error) {
if (!error) return;
*error = [NSError errorWithDomain:kOrtErrorDomain
code:code
userInfo:@{NSLocalizedDescriptionKey : description}];
}
void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error) {
ORTSaveCodeAndDescriptionToError(e.GetOrtErrorCode(), e.what(), error);
}

View file

@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import <Foundation/Foundation.h>
#include <stdint.h>
NS_ASSUME_NONNULL_BEGIN
/**
* An ORT checkpoint is a snapshot of the state of a model at a given point in time.
*
* This class holds the entire training session state that includes model parameters,
* their gradients, optimizer parameters, and user properties. The ORTTrainingSession leverages the
* ORTCheckpointState by accessing and updating the contained training state.
*
* Available since v1.16.0.
*
* @note Note that the training session created with a checkpoint state uses this state to store the entire training
* state (including model parameters, its gradients, the optimizer states and the properties). The ORTTraingSession
* does not hold a copy of the checkpoint state. Therefore, it is required that the checkpoint state outlive the
* lifetime of the training session.
*/
@interface ORTCheckpoint : NSObject
- (instancetype)init NS_UNAVAILABLE;
/**
* Creates a checkpoint from directory on disk.
*
* @param path The path to the checkpoint directory.
* @param error Optional error information set if an error occurs.
* @return The instance, or nil if an error occurs.
*
* @warning The construction of the checkpoint state requires instantiation of `ORTEnv`.
* The intialization will fail if the `ORTEnv` is not properly initialized.
*/
- (nullable instancetype)initWithPath:(NSString*)path
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
/**
* Saves a checkpoint to directory on disk.
*
* @param path The path to the checkpoint directory.
* @param includeOptimizerState Flag to indicate whether to save the optimizer state or not.
* @param error Optional error information set if an error occurs.
* @return Whether the checkpoint was saved successfully.
*/
- (BOOL)saveCheckpointToPath:(NSString*)path
withOptimizerState:(BOOL)includeOptimizerState
error:(NSError**)error;
/**
* Adds an int property to this checkpoint.
*
* @param name The name of the property.
* @param value The value of the property.
* @param error Optional error information set if an error occurs.
* @return Whether the property was added successfully.
*/
- (BOOL)addIntPropertyWithName:(NSString*)name
value:(int64_t)value
error:(NSError**)error;
/**
* Adds a float property to this checkpoint.
*
* @param name The name of the property.
* @param value The value of the property.
* @param error Optional error information set if an error occurs.
* @return Whether the property was added successfully.
*/
- (BOOL)addFloatPropertyWithName:(NSString*)name
value:(float)value
error:(NSError**)error;
/**
* Adds a string property to this checkpoint.
*
* @param name The name of the property.
* @param value The value of the property.
* @param error Optional error information set if an error occurs.
* @return Whether the property was added successfully.
*/
- (BOOL)addStringPropertyWithName:(NSString*)name
value:(NSString*)value
error:(NSError**)error;
/**
* Gets an int property from this checkpoint.
*
* @param name The name of the property.
* @param error Optional error information set if an error occurs.
* @return The value of the property or 0 if an error occurs.
*/
- (int64_t)getIntPropertyWithName:(NSString*)name
error:(NSError**)error __attribute__((swift_error(nonnull_error)));
/**
* Gets a float property from this checkpoint.
*
* @param name The name of the property.
* @param error Optional error information set if an error occurs.
* @return The value of the property or 0.0f if an error occurs.
*/
- (float)getFloatPropertyWithName:(NSString*)name
error:(NSError**)error __attribute__((swift_error(nonnull_error)));
/**
*
* Gets a string property from this checkpoint.
*
* @param name The name of the property.
* @param error Optional error information set if an error occurs.
* @return The value of the property.
*/
- (nullable NSString*)getStringPropertyWithName:(NSString*)name
error:(NSError**)error __attribute__((swift_error(nonnull_error)));
@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -0,0 +1,113 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import "ort_checkpoint_internal.h"
#include <optional>
#include <string>
#include <variant>
#import "cxx_api.h"
#import "error_utils.h"
NS_ASSUME_NONNULL_BEGIN
@implementation ORTCheckpoint {
std::optional<Ort::CheckpointState> _checkpoint;
}
- (nullable instancetype)initWithPath:(NSString*)path
error:(NSError**)error {
if ((self = [super init]) == nil) {
return nil;
}
try {
_checkpoint = Ort::CheckpointState::LoadCheckpoint(path.UTF8String);
return self;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
- (BOOL)saveCheckpointToPath:(NSString*)path
withOptimizerState:(BOOL)includeOptimizerState
error:(NSError**)error {
try {
Ort::CheckpointState::SaveCheckpoint([self CXXAPIOrtCheckpoint], path.UTF8String, includeOptimizerState);
return YES;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
- (BOOL)addIntPropertyWithName:(NSString*)name
value:(int64_t)value
error:(NSError**)error {
try {
[self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value);
return YES;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
- (BOOL)addFloatPropertyWithName:(NSString*)name
value:(float)value
error:(NSError**)error {
try {
[self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value);
return YES;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
- (BOOL)addStringPropertyWithName:(NSString*)name
value:(NSString*)value
error:(NSError**)error {
try {
[self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value.UTF8String);
return YES;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
- (nullable NSString*)getStringPropertyWithName:(NSString*)name error:(NSError**)error {
try {
Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
if (std::string* str = std::get_if<std::string>(&value)) {
return [NSString stringWithUTF8String:str->c_str()];
}
ORT_CXX_API_THROW("Property is not a string.", ORT_INVALID_ARGUMENT);
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
- (int64_t)getIntPropertyWithName:(NSString*)name error:(NSError**)error {
try {
Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
if (int64_t* i = std::get_if<int64_t>(&value)) {
return *i;
}
ORT_CXX_API_THROW("Property is not an integer.", ORT_INVALID_ARGUMENT);
}
ORT_OBJC_API_IMPL_CATCH(error, 0)
}
- (float)getFloatPropertyWithName:(NSString*)name error:(NSError**)error {
try {
Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
if (float* f = std::get_if<float>(&value)) {
return *f;
}
ORT_CXX_API_THROW("Property is not a float.", ORT_INVALID_ARGUMENT);
}
ORT_OBJC_API_IMPL_CATCH(error, 0.0f)
}
- (Ort::CheckpointState&)CXXAPIOrtCheckpoint {
return *_checkpoint;
}
@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import "ort_checkpoint.h"
#import "cxx_api.h"
NS_ASSUME_NONNULL_BEGIN
@interface ORTCheckpoint ()
- (Ort::CheckpointState&)CXXAPIOrtCheckpoint;
@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import <XCTest/XCTest.h>
#import "ort_checkpoint.h"
#import "ort_env.h"
#import "test/assertion_utils.h"
NS_ASSUME_NONNULL_BEGIN
@interface ORTCheckpointTest : XCTestCase
@property(readonly, nullable) ORTEnv* ortEnv;
@end
@implementation ORTCheckpointTest
- (void)setUp {
[super setUp];
self.continueAfterFailure = NO;
NSError* err = nil;
_ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
error:&err];
ORTAssertNullableResultSuccessful(_ortEnv, err);
}
- (NSString*)getCheckpointPath {
NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"];
return path;
}
- (void)testLoadCheckpoint {
NSError* error = nil;
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
}
- (void)testIntProperty {
NSError* error = nil;
// Load checkpoint
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
// Add property
BOOL result = [checkpoint addIntPropertyWithName:@"test" value:314 error:&error];
ORTAssertBoolResultSuccessful(result, error);
// Get property
int64_t value = [checkpoint getIntPropertyWithName:@"test" error:&error];
XCTAssertEqual(value, 314);
}
- (void)testFloatProperty {
NSError* error = nil;
// Load checkpoint
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
// Add property
BOOL result = [checkpoint addFloatPropertyWithName:@"test" value:3.14f error:&error];
ORTAssertBoolResultSuccessful(result, error);
// Get property
float value = [checkpoint getFloatPropertyWithName:@"test" error:&error];
XCTAssertEqual(value, 3.14f);
}
- (void)testStringProperty {
NSError* error = nil;
// Load checkpoint
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
ORTAssertNullableResultSuccessful(checkpoint, error);
// Add property
BOOL result = [checkpoint addStringPropertyWithName:@"test" value:@"hello" error:&error];
ORTAssertBoolResultSuccessful(result, error);
// Get property
NSString* value = [checkpoint getStringPropertyWithName:@"test" error:&error];
XCTAssertEqualObjects(value, @"hello");
}
- (void)tearDown {
_ortEnv = nil;
[super tearDown];
}
@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS

View file

@ -3,5 +3,5 @@ stages:
parameters:
AllowReleasedOpsetOnly: 0
BuildForAllArchs: false
AdditionalBuildFlags: --enable_training
AdditionalBuildFlags: --enable_training --build_objc
WithCache: true