diff --git a/Package.swift b/Package.swift index 7f8bfe0c3c..e053732811 100644 --- a/Package.swift +++ b/Package.swift @@ -32,7 +32,14 @@ let package = Package( .target(name: "OnnxRuntimeBindings", dependencies: ["onnxruntime"], path: "objectivec", - exclude: ["test", "docs", "ReadMe.md", "format_objc.sh"], + exclude: ["test", "docs", "ReadMe.md", "format_objc.sh", + "ort_checkpoint.mm", + "ort_checkpoint_internal.h", + "ort_training_session_internal.h", + "ort_training_session.mm", + "include/ort_checkpoint.h", + "include/ort_training_session.h", + "include/onnxruntime_training.h"], cxxSettings: [ .define("SPM_BUILD"), .unsafeFlags(["-std=c++17", diff --git a/objectivec/cxx_api.h b/objectivec/cxx_api.h index 26acfb8b86..b57c865a92 100644 --- a/objectivec/cxx_api.h +++ b/objectivec/cxx_api.h @@ -20,12 +20,12 @@ #endif // clang-format on -#ifndef ENABLE_TRAINING_APIS -#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_c_api.h) -#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_cxx_api.h) -#else +#if __has_include(ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_c_api.h)) #include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_c_api.h) #include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_training_cxx_api.h) +#else +#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_c_api.h) +#include ORT_C_CXX_HEADER_FILE_PATH(onnxruntime_cxx_api.h) #endif #if __has_include(ORT_C_CXX_HEADER_FILE_PATH(coreml_provider_factory.h)) diff --git a/objectivec/include/ort_checkpoint.h b/objectivec/include/ort_checkpoint.h index 2b0144a38d..dbb61b7e01 100644 --- a/objectivec/include/ort_checkpoint.h +++ b/objectivec/include/ort_checkpoint.h @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import #include @@ -121,5 +120,3 @@ NS_ASSUME_NONNULL_BEGIN @end NS_ASSUME_NONNULL_END - -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/include/ort_training_session.h b/objectivec/include/ort_training_session.h index 54b7e54289..ec0a46d331 100644 --- a/objectivec/include/ort_training_session.h +++ b/objectivec/include/ort_training_session.h @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import #include @@ -257,5 +256,3 @@ void ORTSetSeed(int64_t seed); #endif NS_ASSUME_NONNULL_END - -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_checkpoint.mm b/objectivec/ort_checkpoint.mm index ee88e9c9c1..12386457fa 100644 --- a/objectivec/ort_checkpoint.mm +++ b/objectivec/ort_checkpoint.mm @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import "ort_checkpoint_internal.h" #include @@ -110,4 +109,3 @@ NS_ASSUME_NONNULL_BEGIN @end NS_ASSUME_NONNULL_END -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_checkpoint_internal.h b/objectivec/ort_checkpoint_internal.h index 0001913b6e..3d1550cc59 100644 --- a/objectivec/ort_checkpoint_internal.h +++ b/objectivec/ort_checkpoint_internal.h @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import "ort_checkpoint.h" #import "cxx_api.h" @@ -15,4 +14,3 @@ NS_ASSUME_NONNULL_BEGIN @end NS_ASSUME_NONNULL_END -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_training_session.mm b/objectivec/ort_training_session.mm index e33890f12c..285151b412 100644 --- a/objectivec/ort_training_session.mm +++ b/objectivec/ort_training_session.mm @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import "ort_training_session_internal.h" #import @@ -223,5 +222,3 @@ void ORTSetSeed(int64_t seed) { } NS_ASSUME_NONNULL_END - -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/ort_training_session_internal.h b/objectivec/ort_training_session_internal.h index 402c84eb5b..453c941d5f 100644 --- a/objectivec/ort_training_session_internal.h +++ b/objectivec/ort_training_session_internal.h @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import "ort_training_session.h" #import "cxx_api.h" @@ -15,5 +14,3 @@ NS_ASSUME_NONNULL_BEGIN @end NS_ASSUME_NONNULL_END - -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/test/ort_checkpoint_test.mm b/objectivec/test/ort_checkpoint_test.mm index df97dcf01d..9b2196fabb 100644 --- a/objectivec/test/ort_checkpoint_test.mm +++ b/objectivec/test/ort_checkpoint_test.mm @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import #import "ort_checkpoint.h" @@ -116,5 +115,3 @@ NS_ASSUME_NONNULL_BEGIN @end NS_ASSUME_NONNULL_END - -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/test/ort_training_session_test.mm b/objectivec/test/ort_training_session_test.mm index 30ef51f3a0..683965dc76 100644 --- a/objectivec/test/ort_training_session_test.mm +++ b/objectivec/test/ort_training_session_test.mm @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import #import "ort_checkpoint.h" @@ -358,5 +357,3 @@ NS_ASSUME_NONNULL_BEGIN @end NS_ASSUME_NONNULL_END - -#endif // ENABLE_TRAINING_APIS diff --git a/objectivec/test/ort_training_utils_test.mm b/objectivec/test/ort_training_utils_test.mm index 77636bb237..1695ddc57b 100644 --- a/objectivec/test/ort_training_utils_test.mm +++ b/objectivec/test/ort_training_utils_test.mm @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef ENABLE_TRAINING_APIS #import #import "ort_training_session.h" @@ -25,5 +24,3 @@ NS_ASSUME_NONNULL_BEGIN @end NS_ASSUME_NONNULL_END - -#endif // ENABLE_TRAINING_APIS diff --git a/tools/ci_build/github/apple/c/assemble_c_pod_package.py b/tools/ci_build/github/apple/c/assemble_c_pod_package.py index dd80123c25..14e7729610 100644 --- a/tools/ci_build/github/apple/c/assemble_c_pod_package.py +++ b/tools/ci_build/github/apple/c/assemble_c_pod_package.py @@ -30,6 +30,8 @@ def get_pod_config_file(package_variant: PackageVariant): return _script_dir / "onnxruntime-mobile-c.config.json" elif package_variant == PackageVariant.Test: return _script_dir / "onnxruntime-test-c.config.json" + elif package_variant == PackageVariant.Training: + return _script_dir / "onnxruntime-training-c.config.json" else: raise ValueError(f"Unhandled package variant: {package_variant}") diff --git a/tools/ci_build/github/apple/c/onnxruntime-training-c.config.json b/tools/ci_build/github/apple/c/onnxruntime-training-c.config.json new file mode 100644 index 0000000000..87011c216a --- /dev/null +++ b/tools/ci_build/github/apple/c/onnxruntime-training-c.config.json @@ -0,0 +1,5 @@ +{ + "name": "onnxruntime-training-c", + "summary": "ONNX Runtime Training C/C++ Pod", + "description": "A pod for the ONNX Runtime C/C++ library. This pod supports additional training features." +} diff --git a/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json new file mode 100644 index 0000000000..aa9bdc483d --- /dev/null +++ b/tools/ci_build/github/apple/default_training_ios_framework_build_settings.json @@ -0,0 +1,21 @@ +{ + "build_osx_archs": { + "iphoneos": [ + "arm64" + ], + "iphonesimulator": [ + "arm64", + "x86_64" + ] + }, + "build_params": [ + "--ios", + "--parallel", + "--use_xcode", + "--enable_training_apis", + "--build_apple_framework", + "--skip_tests", + "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF", + "--apple_deploy_target=12.0" + ] +} diff --git a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py index b80c408e98..7d1005a34c 100755 --- a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py +++ b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py @@ -15,6 +15,7 @@ from c.assemble_c_pod_package import get_pod_config_file as get_c_pod_config_fil from package_assembly_utils import ( # noqa: E402 PackageVariant, copy_repo_relative_to_dir, + filter_files, gen_file_from_template, load_json_config, ) @@ -29,31 +30,65 @@ include_dirs = [ "objectivec", ] -# pod source files -source_files = [ - "objectivec/include/*.h", - "objectivec/*.h", - "objectivec/*.m", - "objectivec/*.mm", -] +all_objc_files = { + "source_files": [ + "objectivec/include/*.h", + "objectivec/*.h", + "objectivec/*.m", + "objectivec/*.mm", + ], + "public_header_files": [ + "objectivec/include/*.h", + ], + "test_source_files": [ + "objectivec/test/*.h", + "objectivec/test/*.m", + "objectivec/test/*.mm", + ], + "test_resource_files": [ + "objectivec/test/testdata/*.ort", + "onnxruntime/test/testdata/training_api/*", + ], +} -# pod public header files -# note: these are a subset of source_files -public_header_files = [ - "objectivec/include/*.h", -] +training_only_objc_files = { + "source_files": [ + "objectivec/include/onnxruntime_training.h", + "objectivec/include/ort_checkpoint.h", + "objectivec/include/ort_training_session.h", + "objectivec/ort_checkpoint.mm", + "objectivec/ort_checkpoint_internal.h", + "objectivec/ort_training_session_internal.h", + "objectivec/ort_training_session.mm", + ], + "public_header_files": [ + "objectivec/include/ort_checkpoint.h", + "objectivec/include/ort_training_session.h", + "objectivec/include/onnxruntime_training.h", + ], + "test_source_files": [ + "objectivec/test/ort_training_session_test.mm", + "objectivec/test/ort_checkpoint_test.mm", + "objectivec/test/ort_training_utils_test.mm", + ], + "test_resource_files": [ + "onnxruntime/test/testdata/training_api/*", + ], +} -# pod test source files -test_source_files = [ - "objectivec/test/*.h", - "objectivec/test/*.m", - "objectivec/test/*.mm", -] -# pod test resource files -test_resource_files = [ - "objectivec/test/testdata/*.ort", -] +def get_pod_files(package_variant: PackageVariant): + """ + Gets the source and header files for the given package variant. + """ + if package_variant == PackageVariant.Training: + return all_objc_files + else: + # return files that are in pod_files but not in training_only_objc_files + filtered_pod_files = {} + for key in all_objc_files: + filtered_pod_files[key] = filter_files(all_objc_files[key], training_only_objc_files[key]) + return filtered_pod_files def get_pod_config_file(package_variant: PackageVariant): @@ -64,6 +99,8 @@ def get_pod_config_file(package_variant: PackageVariant): return _script_dir / "onnxruntime-objc.config.json" elif package_variant == PackageVariant.Mobile: return _script_dir / "onnxruntime-mobile-objc.config.json" + elif package_variant == PackageVariant.Training: + return _script_dir / "onnxruntime-training-objc.config.json" else: raise ValueError(f"Unhandled package variant: {package_variant}") @@ -93,8 +130,13 @@ def assemble_objc_pod_package( if staging_dir.exists(): print("Warning: staging directory already exists", file=sys.stderr) + pod_files = get_pod_files(package_variant) + # copy the necessary files to the staging directory - copy_repo_relative_to_dir([license_file, *source_files, *test_source_files, *test_resource_files], staging_dir) + copy_repo_relative_to_dir( + [license_file, *pod_files["source_files"], *pod_files["test_source_files"], *pod_files["test_resource_files"]], + staging_dir, + ) # generate the podspec file from the template @@ -108,11 +150,11 @@ def assemble_objc_pod_package( "IOS_DEPLOYMENT_TARGET": framework_info["IOS_DEPLOYMENT_TARGET"], "LICENSE_FILE": license_file, "NAME": pod_name, - "PUBLIC_HEADER_FILE_LIST": path_patterns_as_variable_value(public_header_files), - "SOURCE_FILE_LIST": path_patterns_as_variable_value(source_files), + "PUBLIC_HEADER_FILE_LIST": path_patterns_as_variable_value(pod_files["public_header_files"]), + "SOURCE_FILE_LIST": path_patterns_as_variable_value(pod_files["source_files"]), "SUMMARY": pod_config["summary"], - "TEST_RESOURCE_FILE_LIST": path_patterns_as_variable_value(test_resource_files), - "TEST_SOURCE_FILE_LIST": path_patterns_as_variable_value(test_source_files), + "TEST_RESOURCE_FILE_LIST": path_patterns_as_variable_value(pod_files["test_resource_files"]), + "TEST_SOURCE_FILE_LIST": path_patterns_as_variable_value(pod_files["test_source_files"]), "VERSION": pod_version, } diff --git a/tools/ci_build/github/apple/objectivec/onnxruntime-training-objc.config.json b/tools/ci_build/github/apple/objectivec/onnxruntime-training-objc.config.json new file mode 100644 index 0000000000..b1cc2d4aad --- /dev/null +++ b/tools/ci_build/github/apple/objectivec/onnxruntime-training-objc.config.json @@ -0,0 +1,5 @@ +{ + "name": "onnxruntime-training-objc", + "summary": "ONNX Runtime Objective-C Pod", + "description": "A pod for the ONNX Runtime Objective-C training API." +} diff --git a/tools/ci_build/github/apple/package_assembly_utils.py b/tools/ci_build/github/apple/package_assembly_utils.py index cb603fadc7..e5940774c5 100644 --- a/tools/ci_build/github/apple/package_assembly_utils.py +++ b/tools/ci_build/github/apple/package_assembly_utils.py @@ -16,6 +16,7 @@ repo_root = _script_dir.parents[3] class PackageVariant(enum.Enum): Full = 0 # full ORT build with all opsets, ops, and types Mobile = 1 # minimal ORT build with reduced ops + Training = 2 # full ORT build with all opsets, ops, and types, plus training APIs Test = -1 # for testing purposes only @classmethod @@ -70,6 +71,27 @@ def gen_file_from_template( output.write(content) +def filter_files(all_file_patterns: List[str], excluded_file_patterns: List[str]): + """ + Filters file paths based on inclusion and exclusion patterns + + :param all_file_patterns The list of file paths to filter. + :param excluded_file_patterns The list of exclusion patterns. + + :return The filtered list of file paths + """ + # get all files matching the patterns in all_file_patterns + all_files = [str(path.relative_to(repo_root)) for pattern in all_file_patterns for path in repo_root.glob(pattern)] + + # get all files matching the patterns in excluded_file_patterns + exclude_files = [ + str(path.relative_to(repo_root)) for pattern in excluded_file_patterns for path in repo_root.glob(pattern) + ] + + # return the difference + return list(set(all_files) - set(exclude_files)) + + def copy_repo_relative_to_dir(patterns: List[str], dest_dir: pathlib.Path): """ Copies file paths relative to the repo root to a directory.