mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Separate training apis from shared core apis (#12027)
This commit is contained in:
parent
d25cf4df26
commit
6e8edfff0c
9 changed files with 162 additions and 126 deletions
|
|
@ -1300,6 +1300,9 @@ function(onnxruntime_configure_target target_name)
|
|||
# set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props)
|
||||
#endif()
|
||||
target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${abseil_cpp_SOURCE_DIR})
|
||||
if (onnxruntime_ENABLE_TRAINING_ON_DEVICE)
|
||||
target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT})
|
||||
endif()
|
||||
if (onnxruntime_ENABLE_LTO)
|
||||
set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)
|
||||
set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELWITHDEBINFO TRUE)
|
||||
|
|
@ -1689,6 +1692,9 @@ set(onnxruntime_DELAYLOAD_FLAGS "")
|
|||
# Directory-scoped header search paths for the public ONNX Runtime headers.
include_directories(
  ${ONNXRUNTIME_INCLUDE_DIR}
  ${REPO_ROOT}/include/onnxruntime/core/session
)
# The training C API headers are only needed when on-device training is enabled.
# if()/endif() are commands in their own right and must not be nested inside
# another command's argument list: written there, CMake hands "if", "(", the
# option name, ")" and "endif" to include_directories() as literal directory
# names and adds the training path unconditionally.
if (onnxruntime_ENABLE_TRAINING_ON_DEVICE)
  include_directories(${REPO_ROOT}/orttraining/orttraining/training_api/include/)
endif()
|
||||
|
||||
if (onnxruntime_USE_OPENVINO)
|
||||
|
|
|
|||
|
|
@ -270,11 +270,6 @@ ORT_RUNTIME_CLASS(CUDAProviderOptionsV2);
|
|||
ORT_RUNTIME_CLASS(Op);
|
||||
ORT_RUNTIME_CLASS(OpAttr);
|
||||
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
ORT_RUNTIME_CLASS(TrainingSession);
|
||||
ORT_RUNTIME_CLASS(CheckpointState);
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
|
||||
#else
|
||||
|
|
@ -539,6 +534,11 @@ typedef struct OrtOpenVINOProviderOptions {
|
|||
struct OrtApi;
|
||||
typedef struct OrtApi OrtApi;
|
||||
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
struct OrtTrainingApi;
|
||||
typedef struct OrtTrainingApi OrtTrainingApi;
|
||||
#endif
|
||||
|
||||
/** \brief The helper interface to get the right version of OrtApi
|
||||
*
|
||||
* Get a pointer to this structure through ::OrtGetApiBase
|
||||
|
|
@ -551,6 +551,14 @@ struct OrtApiBase {
|
|||
* older than the version created with this header file.
|
||||
*/
|
||||
const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION;
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
/** \brief Get a pointer to the requested version of the ::OrtTrainingApi
|
||||
*
|
||||
* \param[in] version Must be ::ORT_API_VERSION
|
||||
* \return The ::OrtTrainingApi for the version requested, nullptr will be returned if this version is unsupported.
|
||||
*/
|
||||
const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION;
|
||||
#endif
|
||||
const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1")
|
||||
};
|
||||
typedef struct OrtApiBase OrtApiBase;
|
||||
|
|
@ -3452,12 +3460,6 @@ struct OrtApi {
|
|||
_In_reads_(num_keys) const char* const* provider_options_keys,
|
||||
_In_reads_(num_keys) const char* const* provider_options_values,
|
||||
_In_ size_t num_keys);
|
||||
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
// defines c apis for on device training scenarios
|
||||
#include "../../../orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h"
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -45,6 +45,11 @@ ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
#include "orttraining/training_api/include/onnxruntime_training_c_api.h"
|
||||
#include "orttraining/training_api/include/ort_training_apis.h"
|
||||
#endif
|
||||
|
||||
#ifdef USE_DML
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t version) NO_EXCEPTION;
|
||||
|
|
@ -2229,6 +2234,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes
|
|||
|
||||
// Static instance of the OrtApiBase entry-point table.
// The initializers are positional, so their order must mirror the member order of
// struct OrtApiBase: GetApi, then (training builds only) GetTrainingApi, then
// GetVersionString.
// NOTE(review): the GetTrainingApi slot exists only when ENABLE_TRAINING_ON_DEVICE is
// defined, which moves GetVersionString one slot further down. Code compiled without
// that define would read a different layout than a library built with it - confirm
// that clients and the library are always built with the same setting.
static constexpr OrtApiBase ort_api_base = {
|
||||
&OrtApis::GetApi,
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
&OrtTrainingApis::GetTrainingApi,
|
||||
#endif
|
||||
&OrtApis::GetVersionString,
|
||||
};
|
||||
|
||||
|
|
@ -2528,21 +2536,6 @@ static constexpr OrtApi ort_api_1_to_12 = {
|
|||
&OrtApis::InvokeOp,
|
||||
&OrtApis::ReleaseOp,
|
||||
&OrtApis::SessionOptionsAppendExecutionProvider_SNPE,
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
// Experimental for on-device training. Always keep at the bottom.
|
||||
&OrtApis::LoadCheckpoint,
|
||||
&OrtApis::SaveCheckpoint,
|
||||
&OrtApis::CreateTrainingSession,
|
||||
&OrtApis::TrainingSessionGetTrainModeOutputCount,
|
||||
&OrtApis::TrainingSessionGetEvalModeOutputCount,
|
||||
&OrtApis::ResetGrad,
|
||||
&OrtApis::TrainStep,
|
||||
&OrtApis::EvalStep,
|
||||
&OrtApis::OptimizerStep,
|
||||
&OrtApis::ReleaseTrainingSession,
|
||||
&OrtApis::ReleaseCheckpointState,
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
// Asserts to do some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other)
|
||||
|
|
|
|||
|
|
@ -381,35 +381,4 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_SNPE,
|
|||
_In_reads_(num_keys) const char* const* provider_options_values,
|
||||
_In_ size_t num_keys);
|
||||
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||
_Outptr_ OrtTrainingSession** out);
|
||||
ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session);
|
||||
|
||||
ORT_API_STATUS_IMPL(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
ORT_API_STATUS_IMPL(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
ORT_API_STATUS_IMPL(ResetGrad, _Inout_ OrtTrainingSession* session);
|
||||
|
||||
ORT_API_STATUS_IMPL(TrainStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
ORT_API_STATUS_IMPL(EvalStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options);
|
||||
|
||||
ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
ORT_API_STATUS_IMPL(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
|
||||
bool save_optimizer_state);
|
||||
|
||||
ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* session);
|
||||
#endif
|
||||
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ def build_gradient_accumulation_graph(grad_model, all_args_requiring_gradient_na
|
|||
# accumulated gradient update flag is also a graph output
|
||||
grad_accumulation_output = onnx.helper.make_tensor_value_info(
|
||||
grad_accumulation_output_name, onnx.TensorProto.BOOL, [1]
|
||||
)
|
||||
)
|
||||
graph_outputs.append(grad_accumulation_output)
|
||||
|
||||
lazy_reset_grad_input = onnx.helper.make_tensor_value_info(lazy_reset_grad_input_name, onnx.TensorProto.BOOL, [1])
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <onnxruntime_c_api.h>
|
||||
#include "orttraining/training_api/include/utils.h"
|
||||
#include <onnxruntime_training_c_api.h>
|
||||
|
||||
#include "cxxopts.hpp"
|
||||
#include "../common/synthetic_data_loader.h"
|
||||
|
|
@ -13,10 +13,10 @@
|
|||
#include "core/providers/cuda/nvtx_profile_context.h"
|
||||
#endif
|
||||
|
||||
using namespace onnxruntime::training::api;
|
||||
using namespace std;
|
||||
|
||||
const OrtApi* g_ort_api = nullptr;
|
||||
const OrtTrainingApi* g_ort_training_api = nullptr;
|
||||
|
||||
struct TestRunnerParameters {
|
||||
// Models configs.
|
||||
|
|
@ -173,6 +173,7 @@ void InitSyntheticDataLoader(
|
|||
|
||||
int RunTraining(const TestRunnerParameters& params) {
|
||||
g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
g_ort_training_api = OrtGetApiBase()->GetTrainingApi(ORT_API_VERSION);
|
||||
|
||||
// Create Env
|
||||
OrtEnv* env;
|
||||
|
|
@ -185,7 +186,7 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
|
||||
// Load Checkpoint State
|
||||
OrtCheckpointState* checkpoint_state;
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->LoadCheckpoint(params.checkpoint_to_load_path.c_str(), &checkpoint_state));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->LoadCheckpoint(params.checkpoint_to_load_path.c_str(), &checkpoint_state));
|
||||
|
||||
// Create TrainingSession
|
||||
OrtSessionOptions* soptions;
|
||||
|
|
@ -199,16 +200,16 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
|
||||
OrtTrainingSession* session;
|
||||
bool do_eval = params.model_evaluation_graph_path.has_value();
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->CreateTrainingSession(env, soptions, checkpoint_state,
|
||||
params.model_training_graph_path.c_str(), do_eval ? params.model_evaluation_graph_path.value().c_str() : nullptr,
|
||||
params.optimizer_training_graph_path.size() > 0 ? params.optimizer_training_graph_path.c_str() : nullptr,
|
||||
&session));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->CreateTrainingSession(env, soptions, checkpoint_state,
|
||||
params.model_training_graph_path.c_str(), do_eval ? params.model_evaluation_graph_path.value().c_str() : nullptr,
|
||||
params.optimizer_training_graph_path.size() > 0 ? params.optimizer_training_graph_path.c_str() : nullptr,
|
||||
&session));
|
||||
|
||||
size_t train_mode_output_count, eval_mode_output_count = 0;
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->TrainingSessionGetTrainModeOutputCount(session, &train_mode_output_count));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->TrainingSessionGetTrainModeOutputCount(session, &train_mode_output_count));
|
||||
|
||||
if (do_eval) {
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->TrainingSessionGetEvalModeOutputCount(session, &eval_mode_output_count));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->TrainingSessionGetEvalModeOutputCount(session, &eval_mode_output_count));
|
||||
}
|
||||
|
||||
int64_t sample_batch_count_per_epoch = 4;
|
||||
|
|
@ -247,9 +248,9 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
#endif
|
||||
|
||||
std::vector<OrtValue*> fetches(train_mode_output_count);
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->TrainStep(session, nullptr,
|
||||
inputs.size(), (const OrtValue* const*)inputs.data(),
|
||||
train_mode_output_count, fetches.data()));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->TrainStep(session, nullptr,
|
||||
inputs.size(), (const OrtValue* const*)inputs.data(),
|
||||
train_mode_output_count, fetches.data()));
|
||||
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
|
||||
train_step_range.End();
|
||||
#endif
|
||||
|
|
@ -266,7 +267,7 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
onnxruntime::profile::Color::Blue);
|
||||
opt_step_range.Begin();
|
||||
#endif
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->OptimizerStep(session, nullptr));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->OptimizerStep(session, nullptr));
|
||||
|
||||
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
|
||||
opt_step_range.End();
|
||||
|
|
@ -282,7 +283,7 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
resetgrad_range.Begin();
|
||||
#endif
|
||||
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->ResetGrad(session));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->ResetGrad(session));
|
||||
|
||||
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
|
||||
resetgrad_range.End();
|
||||
|
|
@ -291,15 +292,15 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
|
||||
if (do_eval && (batch_idx + 1) % params.eval_interval == 0) {
|
||||
std::vector<OrtValue*> eval_results(eval_mode_output_count);
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->EvalStep(session, nullptr,
|
||||
inputs.size(), (const OrtValue* const*)inputs.data(),
|
||||
train_mode_output_count, eval_results.data()));
|
||||
// eval_results was sized with eval_mode_output_count, so that is the output length
// that has to be reported to EvalStep. Passing train_mode_output_count here would
// describe a buffer of a different size than the one actually handed over.
ORT_RETURN_ON_ERROR(g_ort_training_api->EvalStep(session, nullptr,
                                                 inputs.size(), (const OrtValue* const*)inputs.data(),
                                                 eval_mode_output_count, eval_results.data()));
|
||||
}
|
||||
|
||||
if ((batch_idx + 1) % params.checkpoint_interval == 0) {
|
||||
// Save trained weights
|
||||
std::string ckpt_file = params.output_dir + "/ckpt_" + params.model_name + std::to_string(batch_idx);
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
|
||||
|
||||
// TODO(baiju): enable adding more properties to checkpoint
|
||||
// state_to_save.property_bag.AddProperty<int64_t>(std::string("epoch"), epoch);
|
||||
|
|
@ -321,7 +322,7 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
|
||||
// Save trained weights
|
||||
std::string ckpt_file = params.output_dir + "/ckpt_" + params.model_name;
|
||||
ORT_RETURN_ON_ERROR(g_ort_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->SaveCheckpoint(ckpt_file.c_str(), session, true));
|
||||
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> duration_seconds = end - end_to_end_start;
|
||||
|
|
@ -330,14 +331,14 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
std::cout << "Training completed - end to end latency: " << stabilized_total_end_to_end_time << "(s)" << std::endl;
|
||||
|
||||
// Delete all the ptrs
|
||||
g_ort_api->ReleaseTrainingSession(session);
|
||||
g_ort_training_api->ReleaseTrainingSession(session);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// Finally, don't forget to release the provider options
|
||||
g_ort_api->ReleaseCUDAProviderOptions(cuda_options);
|
||||
#endif
|
||||
g_ort_api->ReleaseSessionOptions(soptions);
|
||||
g_ort_api->ReleaseCheckpointState(checkpoint_state);
|
||||
g_ort_training_api->ReleaseCheckpointState(checkpoint_state);
|
||||
g_ort_api->ReleaseEnv(env);
|
||||
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -1,40 +1,44 @@
|
|||
// This file contains c apis for on device training
|
||||
// This file should never be included standalone
|
||||
// It is included from within core/session/onnxruntime_c_api.h when
|
||||
// on device training is enabled
|
||||
// These apis can be moved to core/session/onnxruntime_c_api.h once they stabilize
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// DO NOT UNCOMMENT
|
||||
//#include "core/session/onnxruntime_c_api.h"
|
||||
// This file contains the training c apis.
|
||||
|
||||
#include <stdbool.h>
|
||||
#pragma once
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
ORT_RUNTIME_CLASS(TrainingSession);
|
||||
ORT_RUNTIME_CLASS(CheckpointState);
|
||||
|
||||
ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
|
||||
bool save_optimizer_state);
|
||||
struct OrtTrainingApi {
|
||||
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||
_Outptr_ OrtTrainingSession** out);
|
||||
ORT_API2_STATUS(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
|
||||
bool save_optimizer_state);
|
||||
|
||||
ORT_API2_STATUS(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||
_Outptr_ OrtTrainingSession** out);
|
||||
|
||||
ORT_API2_STATUS(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
ORT_API2_STATUS(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
ORT_API2_STATUS(ResetGrad, _Inout_ OrtTrainingSession* session);
|
||||
ORT_API2_STATUS(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
ORT_API2_STATUS(ResetGrad, _Inout_ OrtTrainingSession* session);
|
||||
|
||||
ORT_API2_STATUS(EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options);
|
||||
ORT_API2_STATUS(EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
ORT_CLASS_RELEASE(TrainingSession);
|
||||
ORT_CLASS_RELEASE(CheckpointState);
|
||||
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options);
|
||||
|
||||
ORT_CLASS_RELEASE(TrainingSession);
|
||||
ORT_CLASS_RELEASE(CheckpointState);
|
||||
};
|
||||
|
||||
typedef struct OrtTrainingApi OrtTrainingApi;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace OrtTrainingApis {
|
||||
|
||||
ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version);
|
||||
|
||||
ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||
_Outptr_ OrtTrainingSession** out);
|
||||
ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session);
|
||||
|
||||
ORT_API_STATUS_IMPL(TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
ORT_API_STATUS_IMPL(TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
ORT_API_STATUS_IMPL(ResetGrad, _Inout_ OrtTrainingSession* session);
|
||||
|
||||
ORT_API_STATUS_IMPL(TrainStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
ORT_API_STATUS_IMPL(EvalStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options);
|
||||
|
||||
ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
ORT_API_STATUS_IMPL(SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* session,
|
||||
bool save_optimizer_state);
|
||||
|
||||
ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* session);
|
||||
|
||||
} // namespace OrtTrainingApis
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "orttraining/training_api/include/onnxruntime_training_c_api.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
#include "core/session/ort_env.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "orttraining/training_api/include/checkpoint.h"
|
||||
#include "orttraining/training_api/include/training_session.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "orttraining/training_api/include/ort_training_apis.h"
|
||||
|
||||
namespace {
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ std::vector<std::shared_ptr<onnxruntime::IExecutionProvider>> CreateProviders(
|
|||
|
||||
} // namespace
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||
_Outptr_ OrtTrainingSession** out) {
|
||||
|
|
@ -60,7 +60,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out) {
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::TrainingSessionGetTrainModeOutputCount, _In_ const OrtTrainingSession* sess,
|
||||
_Out_ size_t* out) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<const onnxruntime::training::api::TrainingSession*>(sess);
|
||||
*out = session->GetTrainModeOutputCount();
|
||||
|
|
@ -68,7 +69,8 @@ ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetTrainModeOutputCount, _In_ const
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out) {
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::TrainingSessionGetEvalModeOutputCount, _In_ const OrtTrainingSession* sess,
|
||||
_Out_ size_t* out) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<const onnxruntime::training::api::TrainingSession*>(sess);
|
||||
*out = session->GetEvalModeOutputCount();
|
||||
|
|
@ -76,7 +78,7 @@ ORT_API_STATUS_IMPL(OrtApis::TrainingSessionGetEvalModeOutputCount, _In_ const O
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ResetGrad, _Inout_ OrtTrainingSession* session) {
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::ResetGrad, _Inout_ OrtTrainingSession* session) {
|
||||
API_IMPL_BEGIN
|
||||
auto train_session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(session);
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(train_session->ResetGrad());
|
||||
|
|
@ -85,9 +87,10 @@ ORT_API_STATUS_IMPL(OrtApis::ResetGrad, _Inout_ OrtTrainingSession* session) {
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs) {
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::TrainStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options, size_t inputs_len,
|
||||
_In_reads_(inputs_len) const OrtValue* const* inputs, size_t outputs_len,
|
||||
_Inout_updates_all_(outputs_len) OrtValue** outputs) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
constexpr int queue_id = 0;
|
||||
|
|
@ -133,9 +136,10 @@ ORT_API_STATUS_IMPL(OrtApis::TrainStep, _Inout_ OrtTrainingSession* sess, _In_op
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs) {
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::EvalStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options, size_t inputs_len,
|
||||
_In_reads_(inputs_len) const OrtValue* const* inputs, size_t outputs_len,
|
||||
_Inout_updates_all_(outputs_len) OrtValue** outputs) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
constexpr int queue_id = 0;
|
||||
|
|
@ -180,7 +184,7 @@ ORT_API_STATUS_IMPL(OrtApis::EvalStep, _Inout_ OrtTrainingSession* sess, _In_opt
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
|
|
@ -195,7 +199,8 @@ ORT_API_STATUS_IMPL(OrtApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state) {
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
||||
_Outptr_ OrtCheckpointState** checkpoint_state) {
|
||||
API_IMPL_BEGIN
|
||||
*checkpoint_state = nullptr;
|
||||
auto chkpt_state = std::make_unique<onnxruntime::training::api::CheckpointState>();
|
||||
|
|
@ -206,8 +211,8 @@ ORT_API_STATUS_IMPL(OrtApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_pa
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Inout_ OrtTrainingSession* sess,
|
||||
bool save_optimizer_state) {
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
||||
_Inout_ OrtTrainingSession* sess, bool save_optimizer_state) {
|
||||
API_IMPL_BEGIN
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
onnxruntime::training::api::CheckpointState chkpt_state;
|
||||
|
|
@ -218,10 +223,29 @@ ORT_API_STATUS_IMPL(OrtApis::SaveCheckpoint, _In_ const ORTCHAR_T* checkpoint_pa
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API(void, OrtApis::ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session) {
|
||||
ORT_API(void, OrtTrainingApis::ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session) {
|
||||
delete reinterpret_cast<onnxruntime::training::api::TrainingSession*>(session);
|
||||
}
|
||||
|
||||
ORT_API(void, OrtApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* checkpoint_state) {
|
||||
ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* checkpoint_state) {
|
||||
delete reinterpret_cast<onnxruntime::training::api::CheckpointState*>(checkpoint_state);
|
||||
}
|
||||
|
||||
// The single OrtTrainingApi function table, returned by OrtTrainingApis::GetTrainingApi.
// The initializers are positional: each entry must stay in the same order as the
// corresponding member of struct OrtTrainingApi, with the two release functions last.
static constexpr OrtTrainingApi ort_training_api = {
|
||||
&OrtTrainingApis::LoadCheckpoint,
|
||||
&OrtTrainingApis::SaveCheckpoint,
|
||||
&OrtTrainingApis::CreateTrainingSession,
|
||||
&OrtTrainingApis::TrainingSessionGetTrainModeOutputCount,
|
||||
&OrtTrainingApis::TrainingSessionGetEvalModeOutputCount,
|
||||
&OrtTrainingApis::ResetGrad,
|
||||
&OrtTrainingApis::TrainStep,
|
||||
&OrtTrainingApis::EvalStep,
|
||||
&OrtTrainingApis::OptimizerStep,
|
||||
&OrtTrainingApis::ReleaseTrainingSession,
|
||||
&OrtTrainingApis::ReleaseCheckpointState,
|
||||
};
|
||||
|
||||
ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) {
|
||||
// No constraints on the API version yet.
|
||||
return &ort_training_api;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue