onnxruntime/objectivec/test/ort_checkpoint_test.mm
Vrajang Parikh 960e320dff
Objective C Training API: TrainingSession (#16374)
### Description
- Implement Objective-C binding for `ORTTrainingSession`
- Add `ORTUtils` utility class to handle conversion between C++ and
Objective-C types
- Add test case for saving checkpoint
- Add unit test cases for `ORTTrainingSession`

### Motivation and Context
This PR is part of implementing Objective-C bindings for the training API.
It implements the Objective-C binding for the training session. The Objective-C
API closely resembles the C++ API.

---------

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
2023-06-28 09:13:56 -07:00

120 lines
3.6 KiB
Text

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef ENABLE_TRAINING_APIS
#import <XCTest/XCTest.h>
#import "ort_checkpoint.h"
#import "ort_training_session.h"
#import "ort_env.h"
#import "ort_session.h"
#import "test/test_utils.h"
#import "test/assertion_utils.h"
NS_ASSUME_NONNULL_BEGIN
/// Tests for the Objective-C `ORTCheckpoint` training binding: loading a checkpoint, saving it
/// back to disk, and round-tripping int, float and string properties.
@interface ORTCheckpointTest : XCTestCase
/// Environment created in -setUp before each test and released in -tearDown.
@property(readonly, nullable) ORTEnv* ortEnv;
@end

@implementation ORTCheckpointTest

- (void)setUp {
  [super setUp];

  // Later steps in each test depend on earlier ones, so stop at the first failed assertion.
  self.continueAfterFailure = NO;

  NSError* err = nil;
  _ortEnv = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
                                           error:&err];
  ORTAssertNullableResultSuccessful(_ortEnv, err);
}

/// @return Path of the `checkpoint.ckpt` resource in the bundle containing this test class.
+ (NSString*)getCheckpointPath {
  NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
  NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"checkpoint.ckpt"];
  return path;
}

/// @return Path of the `training_model.onnx` resource in the bundle containing this test class.
/// NOTE(review): not referenced by the tests in this file.
+ (NSString*)getTrainingModelPath {
  NSBundle* bundle = [NSBundle bundleForClass:[ORTCheckpointTest class]];
  NSString* path = [[bundle resourcePath] stringByAppendingPathComponent:@"training_model.onnx"];
  return path;
}

/// Loads the bundled checkpoint, saves it (without optimizer state) under a temporary directory,
/// then verifies the saved file exists and can itself be loaded as a checkpoint.
- (void)testSaveCheckpoint {
  NSError* error = nil;
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // save checkpoint
  NSString* path = [test_utils::createTemporaryDirectory(self) stringByAppendingPathComponent:@"save_checkpoint.ckpt"];
  XCTAssertNotNil(path);
  BOOL result = [checkpoint saveCheckpointToPath:path withOptimizerState:NO error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // A successful return alone does not prove anything was written: confirm the output exists
  // and round-trips through the loader.
  XCTAssertTrue([[NSFileManager defaultManager] fileExistsAtPath:path]);
  ORTCheckpoint* reloaded = [[ORTCheckpoint alloc] initWithPath:path error:&error];
  ORTAssertNullableResultSuccessful(reloaded, error);
}

/// Loading the bundled checkpoint succeeds.
- (void)testInitCheckpoint {
  NSError* error = nil;
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);
}

/// An int64 property written to the checkpoint reads back with the same value.
- (void)testIntProperty {
  NSError* error = nil;
  // Load checkpoint
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // Add property
  BOOL result = [checkpoint addIntPropertyWithName:@"test" value:314 error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Get property. The scalar return has no failure sentinel, so the error out-parameter is the
  // only failure signal; clear it first so the check below refers to this call alone.
  error = nil;
  int64_t value = [checkpoint getIntPropertyWithName:@"test" error:&error];
  XCTAssertNil(error);
  XCTAssertEqual(value, 314);
}

/// A float property written to the checkpoint reads back with the same value.
- (void)testFloatProperty {
  NSError* error = nil;
  // Load checkpoint
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // Add property
  BOOL result = [checkpoint addFloatPropertyWithName:@"test" value:3.14f error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Get property. As with the int getter, the error out-parameter is the only failure signal.
  error = nil;
  float value = [checkpoint getFloatPropertyWithName:@"test" error:&error];
  XCTAssertNil(error);
  // The value is passed in and read back as a `float` with no arithmetic in between, so an exact
  // match is expected rather than an approximate one.
  XCTAssertEqual(value, 3.14f);
}

/// A string property written to the checkpoint reads back with the same contents.
- (void)testStringProperty {
  NSError* error = nil;
  // Load checkpoint
  ORTCheckpoint* checkpoint = [[ORTCheckpoint alloc] initWithPath:[ORTCheckpointTest getCheckpointPath] error:&error];
  ORTAssertNullableResultSuccessful(checkpoint, error);

  // Add property
  BOOL result = [checkpoint addStringPropertyWithName:@"test" value:@"hello" error:&error];
  ORTAssertBoolResultSuccessful(result, error);

  // Get property. Assert the call itself succeeded before comparing contents, so a failed lookup
  // is reported as an error rather than as a nil-vs-string mismatch.
  NSString* value = [checkpoint getStringPropertyWithName:@"test" error:&error];
  ORTAssertNullableResultSuccessful(value, error);
  XCTAssertEqualObjects(value, @"hello");
}

- (void)tearDown {
  _ortEnv = nil;
  [super tearDown];
}

@end
NS_ASSUME_NONNULL_END
#endif // ENABLE_TRAINING_APIS