Separate training apis from shared core apis (#12027)

This commit is contained in:
Baiju Meswani 2022-06-29 14:12:29 -07:00 committed by GitHub
parent d25cf4df26
commit 6e8edfff0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 162 additions and 126 deletions

View file

@ -1300,6 +1300,9 @@ function(onnxruntime_configure_target target_name)
# set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props)
#endif()
target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${abseil_cpp_SOURCE_DIR})
if (onnxruntime_ENABLE_TRAINING_ON_DEVICE)
target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT})
endif()
if (onnxruntime_ENABLE_LTO)
set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)
set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELWITHDEBINFO TRUE)
@ -1689,6 +1692,9 @@ set(onnxruntime_DELAYLOAD_FLAGS "")
# Header search paths shared by the targets declared after this point.
include_directories(
  ${ONNXRUNTIME_INCLUDE_DIR}
  ${REPO_ROOT}/include/onnxruntime/core/session
)
# The training API headers are only added for on-device training builds.
# if()/endif() are not evaluated inside a command's argument list: there they
# are passed to include_directories() as literal arguments, which would add the
# training path unconditionally together with bogus "if"/"endif" entries.
# The conditional therefore has to be a statement of its own.
if (onnxruntime_ENABLE_TRAINING_ON_DEVICE)
  include_directories(
    ${REPO_ROOT}/orttraining/orttraining/training_api/include/
  )
endif()
if (onnxruntime_USE_OPENVINO)

View file

@ -270,11 +270,6 @@ ORT_RUNTIME_CLASS(CUDAProviderOptionsV2);
ORT_RUNTIME_CLASS(Op);
ORT_RUNTIME_CLASS(OpAttr);
#ifdef ENABLE_TRAINING_ON_DEVICE
ORT_RUNTIME_CLASS(TrainingSession);
ORT_RUNTIME_CLASS(CheckpointState);
#endif
#ifdef _WIN32
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
#else
@ -539,6 +534,11 @@ typedef struct OrtOpenVINOProviderOptions {
struct OrtApi;
typedef struct OrtApi OrtApi;
// Forward declaration of the training function table so that OrtApiBase::GetTrainingApi can
// return a pointer to it. The struct body is defined in onnxruntime_training_c_api.h.
// Only declared when the library is built with ENABLE_TRAINING_ON_DEVICE.
#ifdef ENABLE_TRAINING_ON_DEVICE
struct OrtTrainingApi;
typedef struct OrtTrainingApi OrtTrainingApi;
#endif
/** \brief The helper interface to get the right version of OrtApi
*
* Get a pointer to this structure through ::OrtGetApiBase
@ -551,6 +551,14 @@ struct OrtApiBase {
* older than the version created with this header file.
*/
const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION;
#ifdef ENABLE_TRAINING_ON_DEVICE
/** \brief Get a pointer to the requested version of the ::OrtTrainingApi
*
* \param[in] version Must be ::ORT_API_VERSION
* \return The ::OrtTrainingApi for the version requested, nullptr will be returned if this version is unsupported.
*/
const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION;
#endif
const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1")
};
typedef struct OrtApiBase OrtApiBase;
@ -3452,12 +3460,6 @@ struct OrtApi {
_In_reads_(num_keys) const char* const* provider_options_keys,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);
#ifdef ENABLE_TRAINING_ON_DEVICE
// defines c apis for on device training scenarios
#include "../../../orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h"
#endif
};
/*

View file

@ -45,6 +45,11 @@ ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
}
#endif
#ifdef ENABLE_TRAINING_ON_DEVICE
#include "orttraining/training_api/include/onnxruntime_training_c_api.h"
#include "orttraining/training_api/include/ort_training_apis.h"
#endif
#ifdef USE_DML
#include "core/providers/dml/dml_provider_factory.h"
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t version) NO_EXCEPTION;
@ -2229,6 +2234,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes
// Function table of type OrtApiBase. The initializers are positional, so their order must
// match the member order in the OrtApiBase struct declaration: GetApi, (GetTrainingApi),
// GetVersionString.
// NOTE(review): GetTrainingApi sits between GetApi and GetVersionString and exists only when
// ENABLE_TRAINING_ON_DEVICE is defined, so the struct layout differs between training and
// non-training builds. Clients must be compiled with the same define as the library.
static constexpr OrtApiBase ort_api_base = {
&OrtApis::GetApi,
#ifdef ENABLE_TRAINING_ON_DEVICE
&OrtTrainingApis::GetTrainingApi,
#endif
&OrtApis::GetVersionString,
};
@ -2528,21 +2536,6 @@ static constexpr OrtApi ort_api_1_to_12 = {
&OrtApis::InvokeOp,
&OrtApis::ReleaseOp,
&OrtApis::SessionOptionsAppendExecutionProvider_SNPE,
#ifdef ENABLE_TRAINING_ON_DEVICE
// Experimental for on-device training. Always keep at the bottom.
&OrtApis::LoadCheckpoint,
&OrtApis::SaveCheckpoint,
&OrtApis::CreateTrainingSession,
&OrtApis::TrainingSessionGetTrainModeOutputCount,
&OrtApis::TrainingSessionGetEvalModeOutputCount,
&OrtApis::ResetGrad,
&OrtApis::TrainStep,
&OrtApis::EvalStep,
&OrtApis::OptimizerStep,
&OrtApis::ReleaseTrainingSession,
&OrtApis::ReleaseCheckpointState,
#endif
};
// Asserts that do some checks to ensure older versions of the OrtApi never change (they will detect an addition or a deletion, but not an addition and a deletion that cancel each other out)

View file

@ -381,35 +381,4 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_SNPE,
_In_reads_(num_keys) const char* const* provider_options_values,
_In_ size_t num_keys);
#ifdef ENABLE_TRAINING_ON_DEVICE
ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
_Outptr_ OrtTrainingSession** out);
ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session);
ORT_API_STATUS_IMPL(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
ORT_API_STATUS_IMPL(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
ORT_API_STATUS_IMPL(ResetGrad, _Inout_ OrtTrainingSession* session);
ORT_API_STATUS_IMPL(TrainStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_API_STATUS_IMPL(EvalStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options);
ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
ORT_API_STATUS_IMPL(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
bool save_optimizer_state);
ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* session);
#endif
} // namespace OrtApis

View file

@ -171,7 +171,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
# accumulated gradient update flag is also a graph output
grad_accumulation_output = onnx.helper.make_tensor_value_info(
grad_accumulation_output_name, onnx.TensorProto.BOOL, [1]
)
)
graph_outputs.append(grad_accumulation_output)
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1])

View file

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include <onnxruntime_c_api.h>
#include "orttraining/training_api/include/utils.h"
#include <onnxruntime_training_c_api.h>
#include "cxxopts.hpp"
#include "../common/synthetic_data_loader.h"
@ -13,10 +13,10 @@
#include "core/providers/cuda/nvtx_profile_context.h"
#endif
using namespace onnxruntime::training::api;
using namespace std;
const OrtApi* g_ort_api = nullptr;
const OrtTrainingApi* g_ort_training_api = nullptr;
struct TestRunnerParameters {
// Models configs.
@ -173,6 +173,7 @@ void InitSyntheticDataLoader(
int RunTraining(const TestRunnerParameters& params) {
g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
g_ort_training_api = OrtGetApiBase()->GetTrainingApi(ORT_API_VERSION);
// Create Env
OrtEnv* env;
@ -185,7 +186,7 @@ int RunTraining(const TestRunnerParameters& params) {
// Load Checkpoint State
OrtCheckpointState* checkpoint_state;
ORT_RETURN_ON_ERROR(g_ort_api->LoadCheckpoint(params.checkpoint_to_load_path.c_str(), &checkpoint_state));
ORT_RETURN_ON_ERROR(g_ort_training_api->LoadCheckpoint(params.checkpoint_to_load_path.c_str(), &checkpoint_state));
// Create TrainingSession
OrtSessionOptions* soptions;
@ -199,16 +200,16 @@ int RunTraining(const TestRunnerParameters& params) {
OrtTrainingSession* session;
bool do_eval = params.model_evaluation_graph_path.has_value();
ORT_RETURN_ON_ERROR(g_ort_api->CreateTrainingSession(env, soptions, checkpoint_state,
params.model_training_graph_path.c_str(), do_eval ? params.model_evaluation_graph_path.value().c_str() : nullptr,
params.optimizer_training_graph_path.size() > 0 ? params.optimizer_training_graph_path.c_str() : nullptr,
&session));
ORT_RETURN_ON_ERROR(g_ort_training_api->CreateTrainingSession(env, soptions, checkpoint_state,
params.model_training_graph_path.c_str(), do_eval ? params.model_evaluation_graph_path.value().c_str() : nullptr,
params.optimizer_training_graph_path.size() > 0 ? params.optimizer_training_graph_path.c_str() : nullptr,
&session));
size_t train_mode_output_count, eval_mode_output_count = 0;
ORT_RETURN_ON_ERROR(g_ort_api->TrainingSessionGetTrainModeOutputCount(session, &train_mode_output_count));
ORT_RETURN_ON_ERROR(g_ort_training_api->TrainingSessionGetTrainModeOutputCount(session, &train_mode_output_count));
if (do_eval) {
ORT_RETURN_ON_ERROR(g_ort_api->TrainingSessionGetEvalModeOutputCount(session, &eval_mode_output_count));
ORT_RETURN_ON_ERROR(g_ort_training_api->TrainingSessionGetEvalModeOutputCount(session, &eval_mode_output_count));
}
int64_t sample_batch_count_per_epoch = 4;
@ -247,9 +248,9 @@ int RunTraining(const TestRunnerParameters& params) {
#endif
std::vector<OrtValue*> fetches(train_mode_output_count);
ORT_RETURN_ON_ERROR(g_ort_api->TrainStep(session, nullptr,
inputs.size(), (const OrtValue* const*)inputs.data(),
train_mode_output_count, fetches.data()));
ORT_RETURN_ON_ERROR(g_ort_training_api->TrainStep(session, nullptr,
inputs.size(), (const OrtValue* const*)inputs.data(),
train_mode_output_count, fetches.data()));
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
train_step_range.End();
#endif
@ -266,7 +267,7 @@ int RunTraining(const TestRunnerParameters& params) {
onnxruntime::profile::Color::Blue);
opt_step_range.Begin();
#endif
ORT_RETURN_ON_ERROR(g_ort_api->OptimizerStep(session, nullptr));
ORT_RETURN_ON_ERROR(g_ort_training_api->OptimizerStep(session, nullptr));
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
opt_step_range.End();
@ -282,7 +283,7 @@ int RunTraining(const TestRunnerParameters& params) {
resetgrad_range.Begin();
#endif
ORT_RETURN_ON_ERROR(g_ort_api->ResetGrad(session));
ORT_RETURN_ON_ERROR(g_ort_training_api->ResetGrad(session));
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
resetgrad_range.End();
@ -291,15 +292,15 @@ int RunTraining(const TestRunnerParameters& params) {
if (do_eval && (batch_idx + 1) % params.eval_interval == 0) {
std::vector<OrtValue*> eval_results(eval_mode_output_count);
ORT_RETURN_ON_ERROR(g_ort_api->EvalStep(session, nullptr,
inputs.size(), (const OrtValue* const*)inputs.data(),
train_mode_output_count, eval_results.data()));
// eval_results is sized by eval_mode_output_count, so that is the output length to report.
// Passing the train-mode count would describe a buffer of a different size than the one
// actually handed to EvalStep whenever the two counts differ.
ORT_RETURN_ON_ERROR(g_ort_training_api->EvalStep(session, nullptr,
inputs.size(), (const OrtValue* const*)inputs.data(),
eval_mode_output_count, eval_results.data()));
}
if ((batch_idx + 1) % params.checkpoint_interval == 0) {
// Save trained weights
std::string ckpt_file = params.output_dir + "/ckpt_" + params.model_name + std::to_string(batch_idx);
ORT_RETURN_ON_ERROR(g_ort_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
ORT_RETURN_ON_ERROR(g_ort_training_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
// TODO(baiju): enable adding more properties to checkpoint
// state_to_save.property_bag.AddProperty<int64_t>(std::string("epoch"), epoch);
@ -321,7 +322,7 @@ int RunTraining(const TestRunnerParameters& params) {
// Save trained weights
std::string ckpt_file = params.output_dir + "/ckpt_" + params.model_name;
ORT_RETURN_ON_ERROR(g_ort_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
ORT_RETURN_ON_ERROR(g_ort_training_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> duration_seconds = end - end_to_end_start;
@ -330,14 +331,14 @@ int RunTraining(const TestRunnerParameters& params) {
std::cout << "Training completed - end to end latency: " << stabilized_total_end_to_end_time << "(s)" << std::endl;
// Delete all the ptrs
g_ort_api->ReleaseTrainingSession(session);
g_ort_training_api->ReleaseTrainingSession(session);
#ifdef USE_CUDA
// Finally, don't forget to release the provider options
g_ort_api->ReleaseCUDAProviderOptions(cuda_options);
#endif
g_ort_api->ReleaseSessionOptions(soptions);
g_ort_api->ReleaseCheckpointState(checkpoint_state);
g_ort_training_api->ReleaseCheckpointState(checkpoint_state);
g_ort_api->ReleaseEnv(env);
return 0;

View file

@ -1,40 +1,44 @@
// This file contains c apis for on device training
// This file should never be included standalone
// It is included from within core/session/onnxruntime_c_api.h when
// on device training is enabled
// These apis can be moved to core/session/onnxruntime_c_api.h once they stabilize
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// DO NOT UNCOMMENT
//#include "core/session/onnxruntime_c_api.h"
// This file contains the training c apis.
#include <stdbool.h>
#pragma once
#include "core/session/onnxruntime_c_api.h"
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
ORT_RUNTIME_CLASS(TrainingSession);
ORT_RUNTIME_CLASS(CheckpointState);
ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
bool save_optimizer_state);
struct OrtTrainingApi {
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
_Outptr_ OrtTrainingSession** out);
ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
bool save_optimizer_state);
ORT_API2_STATUS(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
_Outptr_ OrtTrainingSession** out);
ORT_API2_STATUS(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
ORT_API2_STATUS(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
ORT_API2_STATUS(ResetGrad, _Inout_ OrtTrainingSession* session);
ORT_API2_STATUS(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_API2_STATUS(ResetGrad, _Inout_ OrtTrainingSession* session);
ORT_API2_STATUS(EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options);
ORT_API2_STATUS(EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_CLASS_RELEASE(TrainingSession);
ORT_CLASS_RELEASE(CheckpointState);
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options);
ORT_CLASS_RELEASE(TrainingSession);
ORT_CLASS_RELEASE(CheckpointState);
};
typedef struct OrtTrainingApi OrtTrainingApi;

View file

@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

// Declarations of the functions that implement the entries of the OrtTrainingApi function
// table, plus GetTrainingApi, which hands that table out.
// This header includes nothing itself: the ORT_API / ORT_API_STATUS_IMPL macros and the Ort*
// types used below must already be declared by the including file (the visible includers pull
// in onnxruntime_training_c_api.h first - TODO confirm which header provides the macros).
namespace OrtTrainingApis {

// Returns the training function table for the requested API version.
ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version);

// Training session lifetime.
ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
                    _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
                    _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
                    _Outptr_ OrtTrainingSession** out);
ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session);

// Output counts of the training and evaluation graphs.
ORT_API_STATUS_IMPL(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
ORT_API_STATUS_IMPL(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);

// Training loop steps.
// The SAL annotations refer to `inputs_len` / `outputs_len`, i.e. the actual parameter names;
// an annotation naming a non-existent parameter is rejected by static analysis.
ORT_API_STATUS_IMPL(ResetGrad, _Inout_ OrtTrainingSession* session);
ORT_API_STATUS_IMPL(TrainStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
                    size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
                    size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_API_STATUS_IMPL(EvalStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
                    size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
                    size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options);

// Checkpoint handling.
ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
ORT_API_STATUS_IMPL(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
                    bool save_optimizer_state);
ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* checkpoint_state);

}  // namespace OrtTrainingApis

View file

@ -1,14 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/common.h"
#include "orttraining/training_api/include/onnxruntime_training_c_api.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/ort_value.h"
#include "core/session/ort_apis.h"
#include "core/session/ort_env.h"
#include "core/session/abi_session_options_impl.h"
#include "orttraining/training_api/include/checkpoint.h"
#include "orttraining/training_api/include/training_session.h"
#include "core/session/abi_session_options_impl.h"
#include "orttraining/training_api/include/ort_training_apis.h"
namespace {
@ -24,7 +24,7 @@ std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>> CreateProviders(
} // namespace
ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
_Outptr_ OrtTrainingSession** out) {
@ -60,7 +60,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out) {
ORT_API_STATUS_IMPL(OrtTrainingApis::TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess,
_Out_ size_t* out) {
API_IMPL_BEGIN
auto session = reinterpret_cast<const onnxruntime::training::api::TrainingSession*>(sess);
*out = session->GetTrainModeOutputCount();
@ -68,7 +69,8 @@ ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetTrainModeOutputCount, _In_ const
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out) {
ORT_API_STATUS_IMPL(OrtTrainingApis::TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess,
_Out_ size_t* out) {
API_IMPL_BEGIN
auto session = reinterpret_cast<const onnxruntime::training::api::TrainingSession*>(sess);
*out = session->GetEvalModeOutputCount();
@ -76,7 +78,7 @@ ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetEvalModeOutputCount, _In_ const O
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::ResetGrad, _Inout_ OrtTrainingSession* session) {
ORT_API_STATUS_IMPL(OrtTrainingApis::ResetGrad, _Inout_ OrtTrainingSession* session) {
API_IMPL_BEGIN
auto train_session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(session);
ORT_API_RETURN_IF_STATUS_NOT_OK(train_session->ResetGrad());
@ -85,9 +87,10 @@ ORT_API_STATUS_IMPL(OrtApis::ResetGrad, _Inout_ OrtTrainingSession* session) {
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs) {
ORT_API_STATUS_IMPL(OrtTrainingApis::TrainStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options, size_t inputs_len,
_In_reads_(inputs_len) const OrtValue* const* inputs, size_t outputs_len,
_Inout_updates_all_(outputs_len) OrtValue** outputs) {
API_IMPL_BEGIN
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
constexpr int queue_id = 0;
@ -133,9 +136,10 @@ ORT_API_STATUS_IMPL(OrtApis::TrainStep, _Inout_ OrtTrainingSession* sess, _In_op
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs) {
ORT_API_STATUS_IMPL(OrtTrainingApis::EvalStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options, size_t inputs_len,
_In_reads_(inputs_len) const OrtValue* const* inputs, size_t outputs_len,
_Inout_updates_all_(outputs_len) OrtValue** outputs) {
API_IMPL_BEGIN
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
constexpr int queue_id = 0;
@ -180,7 +184,7 @@ ORT_API_STATUS_IMPL(OrtApis::EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options) {
API_IMPL_BEGIN
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
@ -195,7 +199,8 @@ ORT_API_STATUS_IMPL(OrtApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state) {
ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
_Outptr_ OrtCheckpointState** checkpoint_state) {
API_IMPL_BEGIN
*checkpoint_state = nullptr;
auto chkpt_state = std::make_unique<onnxruntime::training::api::CheckpointState>();
@ -206,8 +211,8 @@ ORT_API_STATUS_IMPL(OrtApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_pa
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* sess,
bool save_optimizer_state) {
ORT_API_STATUS_IMPL(OrtTrainingApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
_Inout_ OrtTrainingSession* sess, bool save_optimizer_state) {
API_IMPL_BEGIN
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
onnxruntime::training::api::CheckpointState chkpt_state;
@ -218,10 +223,29 @@ ORT_API_STATUS_IMPL(OrtApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_pa
API_IMPL_END
}
ORT_API(void, OrtApis::ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session) {
ORT_API(void, OrtTrainingApis::ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session) {
delete reinterpret_cast<onnxruntime::training::api::TrainingSession*>(session);
}
ORT_API(void, OrtApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* checkpoint_state) {
ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* checkpoint_state) {
delete reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
}
// The training function table returned by GetTrainingApi.
// The initializers are positional, so their order must match the member order of
// struct OrtTrainingApi in onnxruntime_training_c_api.h (LoadCheckpoint first, the two
// Release* functions last). Reordering one without the other silently rebinds the entries.
static constexpr OrtTrainingApi ort_training_api = {
&OrtTrainingApis::LoadCheckpoint,
&OrtTrainingApis::SaveCheckpoint,
&OrtTrainingApis::CreateTrainingSession,
&OrtTrainingApis::TrainingSessionGetTrainModeOutputCount,
&OrtTrainingApis::TrainingSessionGetEvalModeOutputCount,
&OrtTrainingApis::ResetGrad,
&OrtTrainingApis::TrainStep,
&OrtTrainingApis::EvalStep,
&OrtTrainingApis::OptimizerStep,
&OrtTrainingApis::ReleaseTrainingSession,
&OrtTrainingApis::ReleaseCheckpointState,
};
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
// No constraints on the API version yet.
return &ort_training_api;
}