diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index df39be3e54..9594f77f0e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1300,6 +1300,9 @@ function(onnxruntime_configure_target target_name) # set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props) #endif() target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${abseil_cpp_SOURCE_DIR}) + if (onnxruntime_ENABLE_TRAINING_ON_DEVICE) + target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT}) + endif() if (onnxruntime_ENABLE_LTO) set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELWITHDEBINFO TRUE) @@ -1689,6 +1692,9 @@ set(onnxruntime_DELAYLOAD_FLAGS "") include_directories( ${ONNXRUNTIME_INCLUDE_DIR} ${REPO_ROOT}/include/onnxruntime/core/session + if (onnxruntime_ENABLE_TRAINING_ON_DEVICE) + ${REPO_ROOT}/orttraining/orttraining/training_api/include/ + endif() ) if (onnxruntime_USE_OPENVINO) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 98ad6c9c72..e899428843 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -270,11 +270,6 @@ ORT_RUNTIME_CLASS(CUDAProviderOptionsV2); ORT_RUNTIME_CLASS(Op); ORT_RUNTIME_CLASS(OpAttr); -#ifdef ENABLE_TRAINING_ON_DEVICE -ORT_RUNTIME_CLASS(TrainingSession); -ORT_RUNTIME_CLASS(CheckpointState); -#endif - #ifdef _WIN32 typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; #else @@ -539,6 +534,11 @@ typedef struct OrtOpenVINOProviderOptions { struct OrtApi; typedef struct OrtApi OrtApi; +#ifdef ENABLE_TRAINING_ON_DEVICE +struct OrtTrainingApi; +typedef struct OrtTrainingApi OrtTrainingApi; +#endif + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -551,6 +551,14 @@ struct OrtApiBase { * older than the version created with this header file. */ const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; +#ifdef ENABLE_TRAINING_ON_DEVICE + /** \brief Get a pointer to the requested version of the ::OrtTrainingApi + * + * \param[in] version Must be ::ORT_API_VERSION + * \return The ::OrtTrainingApi for the version requested, nullptr will be returned if this version is unsupported. + */ + const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION; +#endif const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") }; typedef struct OrtApiBase OrtApiBase; @@ -3452,12 +3460,6 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); - -#ifdef ENABLE_TRAINING_ON_DEVICE - // defines c apis for on device training scenarios - #include "../../../orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h" -#endif - }; /* diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 389f9080e5..d04052fcc0 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -45,6 +45,11 @@ ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); } #endif +#ifdef ENABLE_TRAINING_ON_DEVICE +#include "orttraining/training_api/include/onnxruntime_training_c_api.h" +#include "orttraining/training_api/include/ort_training_apis.h" +#endif + #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t version) NO_EXCEPTION; @@ -2229,6 +2234,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, +#ifdef ENABLE_TRAINING_ON_DEVICE + &OrtTrainingApis::GetTrainingApi, +#endif &OrtApis::GetVersionString, }; @@ -2528,21 +2536,6 @@ static constexpr OrtApi ort_api_1_to_12 = { &OrtApis::InvokeOp, &OrtApis::ReleaseOp, &OrtApis::SessionOptionsAppendExecutionProvider_SNPE, -#ifdef ENABLE_TRAINING_ON_DEVICE - // Experimental for on-device training. Always keep at the bottom. - &OrtApis::LoadCheckpoint, - &OrtApis::SaveCheckpoint, - &OrtApis::CreateTrainingSession, - &OrtApis::TrainingSessionGetTrainModeOutputCount, - &OrtApis::TrainingSessionGetEvalModeOutputCount, - &OrtApis::ResetGrad, - &OrtApis::TrainStep, - &OrtApis::EvalStep, - &OrtApis::OptimizerStep, - &OrtApis::ReleaseTrainingSession, - &OrtApis::ReleaseCheckpointState, -#endif - }; // Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 48ac277a91..9c1f4f26fc 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -381,35 +381,4 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_SNPE, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); -#ifdef ENABLE_TRAINING_ON_DEVICE -ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, - _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, - _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); -ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session); - -ORT_API_STATUS_IMPL(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); - -ORT_API_STATUS_IMPL(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); - -ORT_API_STATUS_IMPL(ResetGrad, _Inout_ OrtTrainingSession* session); - -ORT_API_STATUS_IMPL(TrainStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options, - size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs, - size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); - -ORT_API_STATUS_IMPL(EvalStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options, - size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs, - size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); - -ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options); - -ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state); - -ORT_API_STATUS_IMPL(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session, - bool save_optimizer_state); - -ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* session); -#endif - } // namespace OrtApis diff --git a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py index 97304e9985..7ac20620ae 100644 --- a/orttraining/orttraining/python/training/onnxblock/_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_graph_utils.py @@ -171,7 +171,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na # accumulated gradient update flag is also a graph output grad_accumulation_output = onnx.helper.make_tensor_value_info( grad_accumulation_output_name, onnx.TensorProto.BOOL, [1] - ) + ) graph_outputs.append(grad_accumulation_output) lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1]) diff --git a/orttraining/orttraining/test/training_api/trainer/trainer.cc b/orttraining/orttraining/test/training_api/trainer/trainer.cc index 28d5c269f7..3703cef73f 100644 --- a/orttraining/orttraining/test/training_api/trainer/trainer.cc +++ b/orttraining/orttraining/test/training_api/trainer/trainer.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include -#include "orttraining/training_api/include/utils.h" +#include #include "cxxopts.hpp" #include "../common/synthetic_data_loader.h" @@ -13,10 +13,10 @@ #include "core/providers/cuda/nvtx_profile_context.h" #endif -using namespace onnxruntime::training::api; using namespace std; const OrtApi* g_ort_api = nullptr; +const OrtTrainingApi* g_ort_training_api = nullptr; struct TestRunnerParameters { // Models configs. @@ -173,6 +173,7 @@ void InitSyntheticDataLoader( int RunTraining(const TestRunnerParameters& params) { g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + g_ort_training_api = OrtGetApiBase()->GetTrainingApi(ORT_API_VERSION); // Create Env OrtEnv* env; @@ -185,7 +186,7 @@ int RunTraining(const TestRunnerParameters& params) { // Load Checkpoint State OrtCheckpointState* checkpoint_state; - ORT_RETURN_ON_ERROR(g_ort_api->LoadCheckpoint(params.checkpoint_to_load_path.c_str(), &checkpoint_state)); + ORT_RETURN_ON_ERROR(g_ort_training_api->LoadCheckpoint(params.checkpoint_to_load_path.c_str(), &checkpoint_state)); // Create TrainingSession OrtSessionOptions* soptions; @@ -199,16 +200,16 @@ int RunTraining(const TestRunnerParameters& params) { OrtTrainingSession* session; bool do_eval = params.model_evaluation_graph_path.has_value(); - ORT_RETURN_ON_ERROR(g_ort_api->CreateTrainingSession(env, soptions, checkpoint_state, - params.model_training_graph_path.c_str(), do_eval ? params.model_evaluation_graph_path.value().c_str() : nullptr, - params.optimizer_training_graph_path.size() > 0 ? params.optimizer_training_graph_path.c_str() : nullptr, - &session)); + ORT_RETURN_ON_ERROR(g_ort_training_api->CreateTrainingSession(env, soptions, checkpoint_state, + params.model_training_graph_path.c_str(), do_eval ? params.model_evaluation_graph_path.value().c_str() : nullptr, + params.optimizer_training_graph_path.size() > 0 ? params.optimizer_training_graph_path.c_str() : nullptr, + &session)); size_t train_mode_output_count, eval_mode_output_count = 0; - ORT_RETURN_ON_ERROR(g_ort_api->TrainingSessionGetTrainModeOutputCount(session, &train_mode_output_count)); + ORT_RETURN_ON_ERROR(g_ort_training_api->TrainingSessionGetTrainModeOutputCount(session, &train_mode_output_count)); if (do_eval) { - ORT_RETURN_ON_ERROR(g_ort_api->TrainingSessionGetEvalModeOutputCount(session, &eval_mode_output_count)); + ORT_RETURN_ON_ERROR(g_ort_training_api->TrainingSessionGetEvalModeOutputCount(session, &eval_mode_output_count)); } int64_t sample_batch_count_per_epoch = 4; @@ -247,9 +248,9 @@ int RunTraining(const TestRunnerParameters& params) { #endif std::vector fetches(train_mode_output_count); - ORT_RETURN_ON_ERROR(g_ort_api->TrainStep(session, nullptr, - inputs.size(), (const OrtValue* const*)inputs.data(), - train_mode_output_count, fetches.data())); + ORT_RETURN_ON_ERROR(g_ort_training_api->TrainStep(session, nullptr, + inputs.size(), (const OrtValue* const*)inputs.data(), + train_mode_output_count, fetches.data())); #if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE) train_step_range.End(); #endif @@ -266,7 +267,7 @@ int RunTraining(const TestRunnerParameters& params) { onnxruntime::profile::Color::Blue); opt_step_range.Begin(); #endif - ORT_RETURN_ON_ERROR(g_ort_api->OptimizerStep(session, nullptr)); + ORT_RETURN_ON_ERROR(g_ort_training_api->OptimizerStep(session, nullptr)); #if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE) opt_step_range.End(); @@ -282,7 +283,7 @@ int RunTraining(const TestRunnerParameters& params) { resetgrad_range.Begin(); #endif - ORT_RETURN_ON_ERROR(g_ort_api->ResetGrad(session)); + ORT_RETURN_ON_ERROR(g_ort_training_api->ResetGrad(session)); #if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE) resetgrad_range.End(); @@ -291,15 +292,15 @@ int RunTraining(const TestRunnerParameters& params) { if (do_eval && (batch_idx + 1) % params.eval_interval == 0) { std::vector eval_results(eval_mode_output_count); - ORT_RETURN_ON_ERROR(g_ort_api->EvalStep(session, nullptr, - inputs.size(), (const OrtValue* const*)inputs.data(), - train_mode_output_count, eval_results.data())); + ORT_RETURN_ON_ERROR(g_ort_training_api->EvalStep(session, nullptr, + inputs.size(), (const OrtValue* const*)inputs.data(), + train_mode_output_count, eval_results.data())); } if ((batch_idx + 1) % params.checkpoint_interval == 0) { // Save trained weights std::string ckpt_file = params.output_dir + "/ckpt_" + params.model_name + std::to_string(batch_idx); - ORT_RETURN_ON_ERROR(g_ort_api->SaveCheckpoint(ckpt_file.c_str(), session, true)); + ORT_RETURN_ON_ERROR(g_ort_training_api->SaveCheckpoint(ckpt_file.c_str(), session, true)); // TODO(baiju): enable adding more properties to checkpoint // state_to_save.property_bag.AddProperty(std::string("epoch"), epoch); @@ -321,7 +322,7 @@ int RunTraining(const TestRunnerParameters& params) { // Save trained weights std::string ckpt_file = params.output_dir + "/ckpt_" + params.model_name; - ORT_RETURN_ON_ERROR(g_ort_api->SaveCheckpoint(ckpt_file.c_str(), session, true)); + ORT_RETURN_ON_ERROR(g_ort_training_api->SaveCheckpoint(ckpt_file.c_str(), session, true)); auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration_seconds = end - end_to_end_start; @@ -330,14 +331,14 @@ int RunTraining(const TestRunnerParameters& params) { std::cout << "Training completed - end to end latency: " << stabilized_total_end_to_end_time << "(s)" << std::endl; // Delete all the ptrs - g_ort_api->ReleaseTrainingSession(session); + g_ort_training_api->ReleaseTrainingSession(session); #ifdef USE_CUDA // Finally, don't forget to release the provider options g_ort_api->ReleaseCUDAProviderOptions(cuda_options); #endif g_ort_api->ReleaseSessionOptions(soptions); - g_ort_api->ReleaseCheckpointState(checkpoint_state); + g_ort_training_api->ReleaseCheckpointState(checkpoint_state); g_ort_api->ReleaseEnv(env); return 0; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index c143ede565..e9789a1d06 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -1,40 +1,44 @@ -// This file contains c apis for on device training -// This file should never be included standalone -// It is included from within core/session/onnxruntime_c_api.h when -// on device training is enabled -// These apis can be moved to core/session/onnxruntime_c_api.h once they stabilize +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. -// DO NOT UNCOMMENT -//#include "core/session/onnxruntime_c_api.h" +// This file contains the training c apis. -#include +#pragma once +#include "core/session/onnxruntime_c_api.h" -ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state); +ORT_RUNTIME_CLASS(TrainingSession); +ORT_RUNTIME_CLASS(CheckpointState); -ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session, - bool save_optimizer_state); +struct OrtTrainingApi { + ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state); -ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, - _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, - _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session, + bool save_optimizer_state); -ORT_API2_STATUS(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, + _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, + _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, + _Outptr_ OrtTrainingSession** out); -ORT_API2_STATUS(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + ORT_API2_STATUS(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); -ORT_API2_STATUS(ResetGrad, _Inout_ OrtTrainingSession* session); + ORT_API2_STATUS(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); -ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, - size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, - size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); + ORT_API2_STATUS(ResetGrad, _Inout_ OrtTrainingSession* session); -ORT_API2_STATUS(EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, - size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, - size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); + ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, + size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, + size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); -ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess, - _In_opt_ const OrtRunOptions* run_options); + ORT_API2_STATUS(EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, + size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, + size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); -ORT_CLASS_RELEASE(TrainingSession); -ORT_CLASS_RELEASE(CheckpointState); + ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess, + _In_opt_ const OrtRunOptions* run_options); + + ORT_CLASS_RELEASE(TrainingSession); + ORT_CLASS_RELEASE(CheckpointState); +}; + +typedef struct OrtTrainingApi OrtTrainingApi; diff --git a/orttraining/orttraining/training_api/include/ort_training_apis.h b/orttraining/orttraining/training_api/include/ort_training_apis.h new file mode 100644 index 0000000000..a627025fa4 --- /dev/null +++ b/orttraining/orttraining/training_api/include/ort_training_apis.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace OrtTrainingApis { + +ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version); + +ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, + _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, + _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, + _Outptr_ OrtTrainingSession** out); +ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session); + +ORT_API_STATUS_IMPL(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + +ORT_API_STATUS_IMPL(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); + +ORT_API_STATUS_IMPL(ResetGrad, _Inout_ OrtTrainingSession* session); + +ORT_API_STATUS_IMPL(TrainStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options, + size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs, + size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); + +ORT_API_STATUS_IMPL(EvalStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options, + size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs, + size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); + +ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options); + +ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state); + +ORT_API_STATUS_IMPL(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session, + bool save_optimizer_state); + +ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* session); + +} // namespace OrtTrainingApis diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 5a6237018c..793c7a6219 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/common/common.h" +#include "orttraining/training_api/include/onnxruntime_training_c_api.h" #include "core/framework/error_code_helper.h" -#include "core/framework/ort_value.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" +#include "core/session/abi_session_options_impl.h" #include "orttraining/training_api/include/checkpoint.h" #include "orttraining/training_api/include/training_session.h" -#include "core/session/abi_session_options_impl.h" +#include "orttraining/training_api/include/ort_training_apis.h" namespace { @@ -24,7 +24,7 @@ std::vector> CreateProviders( } // namespace -ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, +ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) { @@ -60,7 +60,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtTrainingApis::TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, + _Out_ size_t* out) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); *out = session->GetTrainModeOutputCount(); @@ -68,7 +69,8 @@ ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetTrainModeOutputCount, _In_ const API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out) { +ORT_API_STATUS_IMPL(OrtTrainingApis::TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, + _Out_ size_t* out) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); *out = session->GetEvalModeOutputCount(); @@ -76,7 +78,7 @@ ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetEvalModeOutputCount, _In_ const O API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::ResetGrad, _Inout_ OrtTrainingSession* session) { +ORT_API_STATUS_IMPL(OrtTrainingApis::ResetGrad, _Inout_ OrtTrainingSession* session) { API_IMPL_BEGIN auto train_session = reinterpret_cast(session); ORT_API_RETURN_IF_STATUS_NOT_OK(train_session->ResetGrad()); @@ -85,9 +87,10 @@ ORT_API_STATUS_IMPL(OrtApis::ResetGrad, _Inout_ OrtTrainingSession* session) { API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, - size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, - size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs) { +ORT_API_STATUS_IMPL(OrtTrainingApis::TrainStep, _Inout_ OrtTrainingSession* sess, + _In_opt_ const OrtRunOptions* run_options, size_t inputs_len, + _In_reads_(inputs_len) const OrtValue* const* inputs, size_t outputs_len, + _Inout_updates_all_(outputs_len) OrtValue** outputs) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); constexpr int queue_id = 0; @@ -133,9 +136,10 @@ ORT_API_STATUS_IMPL(OrtApis::TrainStep, _Inout_ OrtTrainingSession* sess, _In_op API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, - size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, - size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs) { +ORT_API_STATUS_IMPL(OrtTrainingApis::EvalStep, _Inout_ OrtTrainingSession* sess, + _In_opt_ const OrtRunOptions* run_options, size_t inputs_len, + _In_reads_(inputs_len) const OrtValue* const* inputs, size_t outputs_len, + _Inout_updates_all_(outputs_len) OrtValue** outputs) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); constexpr int queue_id = 0; @@ -180,7 +184,7 @@ ORT_API_STATUS_IMPL(OrtApis::EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::OptimizerStep, _Inout_ OrtTrainingSession* sess, +ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); @@ -195,7 +199,8 @@ ORT_API_STATUS_IMPL(OrtApis::OptimizerStep, _Inout_ OrtTrainingSession* sess, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state) { +ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, + _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN *checkpoint_state = nullptr; auto chkpt_state = std::make_unique(); @@ -206,8 +211,8 @@ ORT_API_STATUS_IMPL(OrtApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_pa API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* sess, - bool save_optimizer_state) { +ORT_API_STATUS_IMPL(OrtTrainingApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, + _Inout_ OrtTrainingSession* sess, bool save_optimizer_state) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); onnxruntime::training::api::CheckpointState chkpt_state; @@ -218,10 +223,29 @@ ORT_API_STATUS_IMPL(OrtApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_pa API_IMPL_END } -ORT_API(void, OrtApis::ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session) { +ORT_API(void, OrtTrainingApis::ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session) { delete reinterpret_cast(session); } -ORT_API(void, OrtApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* checkpoint_state) { +ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* checkpoint_state) { delete reinterpret_cast(checkpoint_state); } + +static constexpr OrtTrainingApi ort_training_api = { + &OrtTrainingApis::LoadCheckpoint, + &OrtTrainingApis::SaveCheckpoint, + &OrtTrainingApis::CreateTrainingSession, + &OrtTrainingApis::TrainingSessionGetTrainModeOutputCount, + &OrtTrainingApis::TrainingSessionGetEvalModeOutputCount, + &OrtTrainingApis::ResetGrad, + &OrtTrainingApis::TrainStep, + &OrtTrainingApis::EvalStep, + &OrtTrainingApis::OptimizerStep, + &OrtTrainingApis::ReleaseTrainingSession, + &OrtTrainingApis::ReleaseCheckpointState, +}; + +ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { + // No constraints on the API version yet. + return &ort_training_api; +}