mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Objective-C binding for ORT training (#16127)
### Description Implement Objective-C binding for `ORTCheckPoint`. Additionally, - Modify `onnxruntime_objectivec.cmake` to only include training header and sources when training flag is enabled - Enable objective-c binding for `orttraining-mac-ci-pipeline` ### Motivation and Context This PR is part of implementing Objective-C bindings for training API. It implements objective-c binding for ORTCheckPoint class. The objective-C API closely resembles the C++ API. **Note**: The test for saving checkpoint is skipped as it requires use of training session. It will be added when the objective-c binding for `ORTTrainingSession` is added.
This commit is contained in:
parent
bca49d62a0
commit
67f4a4fd16
9 changed files with 402 additions and 17 deletions
|
|
@ -40,6 +40,16 @@ file(GLOB onnxruntime_objc_srcs CONFIGURE_DEPENDS
|
|||
"${OBJC_ROOT}/*.m"
|
||||
"${OBJC_ROOT}/*.mm")
|
||||
|
||||
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
|
||||
list(REMOVE_ITEM onnxruntime_objc_headers
|
||||
"${OBJC_ROOT}/include/ort_checkpoint.h")
|
||||
|
||||
list(REMOVE_ITEM onnxruntime_objc_srcs
|
||||
"${OBJC_ROOT}/ort_checkpoint_internal.h"
|
||||
"${OBJC_ROOT}/ort_checkpoint.mm")
|
||||
endif()
|
||||
|
||||
|
||||
source_group(TREE "${OBJC_ROOT}" FILES
|
||||
${onnxruntime_objc_headers}
|
||||
${onnxruntime_objc_srcs})
|
||||
|
|
@ -61,6 +71,13 @@ if(onnxruntime_USE_COREML)
|
|||
"${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml")
|
||||
endif()
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING_APIS)
|
||||
target_include_directories(onnxruntime_objc
|
||||
PRIVATE
|
||||
"${ORTTRAINING_SOURCE_DIR}/training_api/include/")
|
||||
|
||||
endif()
|
||||
|
||||
find_library(FOUNDATION_LIB Foundation REQUIRED)
|
||||
|
||||
target_link_libraries(onnxruntime_objc
|
||||
|
|
@ -105,6 +122,12 @@ if(onnxruntime_BUILD_UNIT_TESTS)
|
|||
"${OBJC_ROOT}/test/*.m"
|
||||
"${OBJC_ROOT}/test/*.mm")
|
||||
|
||||
if(NOT onnxruntime_ENABLE_TRAINING_APIS)
|
||||
list(REMOVE_ITEM onnxruntime_objc_test_srcs
|
||||
"${OBJC_ROOT}/test/ort_checkpoint_test.mm")
|
||||
|
||||
endif()
|
||||
|
||||
source_group(TREE "${OBJC_ROOT}" FILES ${onnxruntime_objc_test_srcs})
|
||||
|
||||
xctest_add_bundle(onnxruntime_objc_test onnxruntime_objc
|
||||
|
|
@ -124,6 +147,7 @@ if(onnxruntime_BUILD_UNIT_TESTS)
|
|||
add_custom_command(TARGET onnxruntime_objc_test POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||
"${OBJC_ROOT}/test/testdata"
|
||||
"${ONNXRUNTIME_ROOT}/test/testdata/training_api"
|
||||
"$<TARGET_BUNDLE_CONTENT_DIR:onnxruntime_objc_test>/Resources")
|
||||
|
||||
xctest_add_test(XCTest.onnxruntime_objc_test onnxruntime_objc_test)
|
||||
|
|
|
|||
|
|
@ -11,29 +11,28 @@
|
|||
#endif // defined(__clang__)
|
||||
|
||||
// paths are different when building the Swift Package Manager package as the headers come from the iOS pod archive
|
||||
// clang-format off
|
||||
#define STRINGIFY(x) #x
|
||||
#ifdef SPM_BUILD
|
||||
#include "onnxruntime/onnxruntime_c_api.h"
|
||||
#include "onnxruntime/onnxruntime_cxx_api.h"
|
||||
|
||||
#if __has_include("onnxruntime/coreml_provider_factory.h")
|
||||
#define ORT_OBJC_API_COREML_EP_AVAILABLE 1
|
||||
#include "onnxruntime/coreml_provider_factory.h"
|
||||
#define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(onnxruntime/x)
|
||||
#else
|
||||
#define ORT_OBJC_API_COREML_EP_AVAILABLE 0
|
||||
#define ORT_C_CXX_HEADER_FILE_PATH(x) STRINGIFY(x)
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
#ifndef ENABLE_TRAINING_APIS
|
||||
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_c_api.h)
|
||||
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_cxx_api.h)
|
||||
#else
|
||||
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_c_api.h)
|
||||
#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_cxx_api.h)
|
||||
#endif
|
||||
|
||||
#else
|
||||
#include "onnxruntime_c_api.h"
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
|
||||
#if __has_include("coreml_provider_factory.h")
|
||||
#if __has_include(ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h))
|
||||
#define ORT_OBJC_API_COREML_EP_AVAILABLE 1
|
||||
#include "coreml_provider_factory.h"
|
||||
#include ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h)
|
||||
#else
|
||||
#define ORT_OBJC_API_COREML_EP_AVAILABLE 0
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(__clang__)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
void ORTSaveCodeAndDescriptionToError(int code, const char* description, NSError** error);
|
||||
void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error);
|
||||
void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error);
|
||||
void ORTSaveExceptionToError(const std::exception& e, NSError** error);
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,14 @@ void ORTSaveCodeAndDescriptionToError(int code, const char* descriptionCstr, NSE
|
|||
userInfo:@{NSLocalizedDescriptionKey : description}];
|
||||
}
|
||||
|
||||
void ORTSaveCodeAndDescriptionToError(int code, NSString* description, NSError** error) {
|
||||
if (!error) return;
|
||||
|
||||
*error = [NSError errorWithDomain:kOrtErrorDomain
|
||||
code:code
|
||||
userInfo:@{NSLocalizedDescriptionKey : description}];
|
||||
}
|
||||
|
||||
void ORTSaveOrtExceptionToError(const Ort::Exception& e, NSError** error) {
|
||||
ORTSaveCodeAndDescriptionToError(e.GetOrtErrorCode(), e.what(), error);
|
||||
}
|
||||
|
|
|
|||
125
objectivec/include/ort_checkpoint.h
Normal file
125
objectivec/include/ort_checkpoint.h
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import <Foundation/Foundation.h>
|
||||
#include <stdint.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* An ORT checkpoint is a snapshot of the state of a model at a given point in time.
|
||||
*
|
||||
* This class holds the entire training session state that includes model parameters,
|
||||
* their gradients, optimizer parameters, and user properties. The ORTTrainingSession leverages the
|
||||
* ORTCheckpointState by accessing and updating the contained training state.
|
||||
*
|
||||
* Available since v1.16.0.
|
||||
*
|
||||
* @note Note that the training session created with a checkpoint state uses this state to store the entire training
|
||||
* state (including model parameters, its gradients, the optimizer states and the properties). The ORTTraingSession
|
||||
* does not hold a copy of the checkpoint state. Therefore, it is required that the checkpoint state outlive the
|
||||
* lifetime of the training session.
|
||||
*/
|
||||
@interface ORTCheckpoint : NSObject
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
/**
|
||||
* Creates a checkpoint from directory on disk.
|
||||
*
|
||||
* @param path The path to the checkpoint directory.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The instance, or nil if an error occurs.
|
||||
*
|
||||
* @warning The construction of the checkpoint state requires instantiation of `ORTEnv`.
|
||||
* The intialization will fail if the `ORTEnv` is not properly initialized.
|
||||
*/
|
||||
- (nullable instancetype)initWithPath:(NSString*)path
|
||||
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
* Saves a checkpoint to directory on disk.
|
||||
*
|
||||
* @param path The path to the checkpoint directory.
|
||||
* @param includeOptimizerState Flag to indicate whether to save the optimizer state or not.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return Whether the checkpoint was saved successfully.
|
||||
*/
|
||||
- (BOOL)saveCheckpointToPath:(NSString*)path
|
||||
withOptimizerState:(BOOL)includeOptimizerState
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Adds an int property to this checkpoint.
|
||||
*
|
||||
* @param name The name of the property.
|
||||
* @param value The value of the property.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return Whether the property was added successfully.
|
||||
*/
|
||||
- (BOOL)addIntPropertyWithName:(NSString*)name
|
||||
value:(int64_t)value
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Adds a float property to this checkpoint.
|
||||
*
|
||||
* @param name The name of the property.
|
||||
* @param value The value of the property.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return Whether the property was added successfully.
|
||||
*/
|
||||
- (BOOL)addFloatPropertyWithName:(NSString*)name
|
||||
value:(float)value
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Adds a string property to this checkpoint.
|
||||
*
|
||||
* @param name The name of the property.
|
||||
* @param value The value of the property.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return Whether the property was added successfully.
|
||||
*/
|
||||
|
||||
- (BOOL)addStringPropertyWithName:(NSString*)name
|
||||
value:(NSString*)value
|
||||
error:(NSError**)error;
|
||||
|
||||
/**
|
||||
* Gets an int property from this checkpoint.
|
||||
*
|
||||
* @param name The name of the property.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The value of the property or 0 if an error occurs.
|
||||
*/
|
||||
- (int64_t)getIntPropertyWithName:(NSString*)name
|
||||
error:(NSError**)error __attribute__((swift_error(nonnull_error)));
|
||||
|
||||
/**
|
||||
* Gets a float property from this checkpoint.
|
||||
*
|
||||
* @param name The name of the property.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The value of the property or 0.0f if an error occurs.
|
||||
*/
|
||||
- (float)getFloatPropertyWithName:(NSString*)name
|
||||
error:(NSError**)error __attribute__((swift_error(nonnull_error)));
|
||||
|
||||
/**
|
||||
*
|
||||
* Gets a string property from this checkpoint.
|
||||
*
|
||||
* @param name The name of the property.
|
||||
* @param error Optional error information set if an error occurs.
|
||||
* @return The value of the property.
|
||||
*/
|
||||
- (nullable NSString*)getStringPropertyWithName:(NSString*)name
|
||||
error:(NSError**)error __attribute__((swift_error(nonnull_error)));
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
113
objectivec/ort_checkpoint.mm
Normal file
113
objectivec/ort_checkpoint.mm
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import "ort_checkpoint_internal.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
#import "cxx_api.h"
|
||||
|
||||
#import "error_utils.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@implementation ORTCheckpoint {
|
||||
std::optional<Ort::CheckpointState> _checkpoint;
|
||||
}
|
||||
|
||||
- (nullable instancetype)initWithPath:(NSString*)path
|
||||
error:(NSError**)error {
|
||||
if ((self = [super init]) == nil) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
try {
|
||||
_checkpoint = Ort::CheckpointState::LoadCheckpoint(path.UTF8String);
|
||||
return self;
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
|
||||
}
|
||||
|
||||
- (BOOL)saveCheckpointToPath:(NSString*)path
|
||||
withOptimizerState:(BOOL)includeOptimizerState
|
||||
error:(NSError**)error {
|
||||
try {
|
||||
Ort::CheckpointState::SaveCheckpoint([self CXXAPIOrtCheckpoint], path.UTF8String, includeOptimizerState);
|
||||
return YES;
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
|
||||
}
|
||||
|
||||
- (BOOL)addIntPropertyWithName:(NSString*)name
|
||||
value:(int64_t)value
|
||||
error:(NSError**)error {
|
||||
try {
|
||||
[self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value);
|
||||
return YES;
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
|
||||
}
|
||||
|
||||
- (BOOL)addFloatPropertyWithName:(NSString*)name
|
||||
value:(float)value
|
||||
error:(NSError**)error {
|
||||
try {
|
||||
[self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value);
|
||||
return YES;
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
|
||||
}
|
||||
|
||||
- (BOOL)addStringPropertyWithName:(NSString*)name
|
||||
value:(NSString*)value
|
||||
error:(NSError**)error {
|
||||
try {
|
||||
[self CXXAPIOrtCheckpoint].AddProperty(name.UTF8String, value.UTF8String);
|
||||
return YES;
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
|
||||
}
|
||||
|
||||
- (nullable NSString*)getStringPropertyWithName:(NSString*)name error:(NSError**)error {
|
||||
try {
|
||||
Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
|
||||
if (std::string* str = std::get_if<std::string>(&value)) {
|
||||
return [NSString stringWithUTF8String:str->c_str()];
|
||||
}
|
||||
ORT_CXX_API_THROW("Property is not a string.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
|
||||
}
|
||||
|
||||
- (int64_t)getIntPropertyWithName:(NSString*)name error:(NSError**)error {
|
||||
try {
|
||||
Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
|
||||
if (int64_t* i = std::get_if<int64_t>(&value)) {
|
||||
return *i;
|
||||
}
|
||||
ORT_CXX_API_THROW("Property is not an integer.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH(error, 0)
|
||||
}
|
||||
|
||||
- (float)getFloatPropertyWithName:(NSString*)name error:(NSError**)error {
|
||||
try {
|
||||
Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
|
||||
if (float* f = std::get_if<float>(&value)) {
|
||||
return *f;
|
||||
}
|
||||
ORT_CXX_API_THROW("Property is not a float.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
ORT_OBJC_API_IMPL_CATCH(error, 0.0f)
|
||||
}
|
||||
|
||||
- (Ort::CheckpointState&)CXXAPIOrtCheckpoint {
|
||||
return *_checkpoint;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
18
objectivec/ort_checkpoint_internal.h
Normal file
18
objectivec/ort_checkpoint_internal.h
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import "ort_checkpoint.h"
|
||||
|
||||
#import "cxx_api.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface ORTCheckpoint ()
|
||||
|
||||
- (Ort::CheckpointState&)CXXAPIOrtCheckpoint;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
97
objectivec/test/ort_checkpoint_test.mm
Normal file
97
objectivec/test/ort_checkpoint_test.mm
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef ENABLE_TRAINING_APIS
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "ort_checkpoint.h"
|
||||
#import "ort_env.h"
|
||||
#import "test/assertion_utils.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface ORTCheckpointTest : XCTestCase
|
||||
@property(readonly, nullable) ORTEnv* ortEnv;
|
||||
@end
|
||||
|
||||
@implementation ORTCheckpointTest
|
||||
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
|
||||
self.continueAfterFailure = NO;
|
||||
|
||||
NSError* err = nil;
|
||||
_ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
|
||||
error:&err];
|
||||
ORTAssertNullableResultSuccessful(_ortEnv, err);
|
||||
}
|
||||
|
||||
- (NSString*)getCheckpointPath {
|
||||
NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
|
||||
NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"];
|
||||
return path;
|
||||
}
|
||||
|
||||
- (void)testLoadCheckpoint {
|
||||
NSError* error = nil;
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
}
|
||||
|
||||
- (void)testIntProperty {
|
||||
NSError* error = nil;
|
||||
// Load checkpoint
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
|
||||
// Add property
|
||||
BOOL result = [checkpoint addIntPropertyWithName:@"test" value:314 error:&error];
|
||||
ORTAssertBoolResultSuccessful(result, error);
|
||||
|
||||
// Get property
|
||||
int64_t value = [checkpoint getIntPropertyWithName:@"test" error:&error];
|
||||
XCTAssertEqual(value, 314);
|
||||
}
|
||||
|
||||
- (void)testFloatProperty {
|
||||
NSError* error = nil;
|
||||
// Load checkpoint
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
|
||||
// Add property
|
||||
BOOL result = [checkpoint addFloatPropertyWithName:@"test" value:3.14f error:&error];
|
||||
ORTAssertBoolResultSuccessful(result, error);
|
||||
|
||||
// Get property
|
||||
float value = [checkpoint getFloatPropertyWithName:@"test" error:&error];
|
||||
XCTAssertEqual(value, 3.14f);
|
||||
}
|
||||
|
||||
- (void)testStringProperty {
|
||||
NSError* error = nil;
|
||||
// Load checkpoint
|
||||
ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[self getCheckpointPath] error:&error];
|
||||
ORTAssertNullableResultSuccessful(checkpoint, error);
|
||||
|
||||
// Add property
|
||||
BOOL result = [checkpoint addStringPropertyWithName:@"test" value:@"hello" error:&error];
|
||||
ORTAssertBoolResultSuccessful(result, error);
|
||||
|
||||
// Get property
|
||||
NSString* value = [checkpoint getStringPropertyWithName:@"test" error:&error];
|
||||
XCTAssertEqualObjects(value, @"hello");
|
||||
}
|
||||
|
||||
- (void)tearDown {
|
||||
_ortEnv = nil;
|
||||
|
||||
[super tearDown];
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
||||
#endif // ENABLE_TRAINING_APIS
|
||||
|
|
@ -3,5 +3,5 @@ stages:
|
|||
parameters:
|
||||
AllowReleasedOpsetOnly: 0
|
||||
BuildForAllArchs: false
|
||||
AdditionalBuildFlags: --enable_training
|
||||
AdditionalBuildFlags: --enable_training --build_objc
|
||||
WithCache: true
|
||||
|
|
|
|||
Loading…
Reference in a new issue