mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Add Learning Rate Scheduler C API (#11957)
This commit is contained in:
parent
73da3f3705
commit
f5e3517c39
7 changed files with 177 additions and 11 deletions
|
|
@ -256,10 +256,11 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
onnxruntime::training::test::training_api::SyntheticDataLoader data_loader;
|
||||
InitSyntheticDataLoader(data_loader, params, num_of_batches_per_epoch);
|
||||
|
||||
// TODO(baiju): Add C API for LRScheduler
|
||||
// int64_t total_step_count = params.num_train_epochs * num_of_batches_per_epoch;
|
||||
// int64_t warmup_step_count = total_step_count / 3;
|
||||
// Ort::OrtLinearLRScheduler scheduler = Ort::OrtLinearLRScheduler(optimizer, warmup_step_count, total_step_count);
|
||||
auto lr_scheduler_parameters = std::make_unique<OrtLinearLRSchedulerParameters>().get();
|
||||
lr_scheduler_parameters->total_step_count = params.num_train_epochs * num_of_batches_per_epoch;
|
||||
lr_scheduler_parameters->warmup_step_count = lr_scheduler_parameters->total_step_count / 3;
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->RegisterLRScheduler(
|
||||
session, reinterpret_cast<void*>(lr_scheduler_parameters), OrtLRSchedulerType::LinearLRScheduler, nullptr));
|
||||
|
||||
std::cout << "Initialization completed. Now starting training loop." << std::endl;
|
||||
const int64_t stabilized_perf_start_step = 0;
|
||||
|
|
@ -309,7 +310,7 @@ int RunTraining(const TestRunnerParameters& params) {
|
|||
#endif
|
||||
|
||||
// Update learning rate.
|
||||
// EnforceCheck(scheduler.Step(), "Failed during shceduler.Step()");
|
||||
ORT_RETURN_ON_ERROR(g_ort_training_api->SchedulerStep(session));
|
||||
|
||||
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
|
||||
onnxruntime::profile::NvtxRangeCreator resetgrad_range(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,15 @@
|
|||
ORT_RUNTIME_CLASS(TrainingSession); /// Type that enables performing training for the given user models.
|
||||
ORT_RUNTIME_CLASS(CheckpointState); /// Type that holds the training states for the training session.
|
||||
|
||||
typedef enum OrtLRSchedulerType {
|
||||
LinearLRScheduler
|
||||
} OrtLRSchedulerType;
|
||||
|
||||
typedef struct OrtLinearLRSchedulerParameters {
|
||||
int64_t warmup_step_count;
|
||||
int64_t total_step_count;
|
||||
} OrtLinearLRSchedulerParameters;
|
||||
|
||||
struct OrtTrainingApi {
|
||||
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
|
||||
*
|
||||
|
|
@ -144,6 +153,26 @@ struct OrtTrainingApi {
|
|||
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
/** \brief Sets the learning rate for this training session.
|
||||
*
|
||||
* This function allows users to set the learning rate for the training session. The current
|
||||
* learning rate is maintained by the training session and can be overwritten by invoking
|
||||
* this function with the desired learning rate. This function should not be used when a valid
|
||||
* learning rate scheduler is registered. It should be used either to set the learning rate
|
||||
* derived from a custom learning rate scheduler or to set the learning rate constant to be used
|
||||
* throughout the training session.
|
||||
* Please note that this function does not set the initial learning rate that may be needed
|
||||
* by the predefined learning rate schedulers. To set the initial learning rate for learning
|
||||
* rate schedulers, please look at the function `RegisterLRScheduler`.
|
||||
*
|
||||
* \param[in] sess The training session on which the learning rate needs to be set.
|
||||
* \param[in] learning_rate Desired learning rate to set.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
|
||||
|
||||
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
|
||||
*
|
||||
* This function performs the weight update step that updates the trainable parameters such that they
|
||||
|
|
@ -161,6 +190,37 @@ struct OrtTrainingApi {
|
|||
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options);
|
||||
|
||||
/** \brief Registers the use of the given learning rate scheduler for the training session.
|
||||
*
|
||||
* Register a learning rate scheduler identified by the given enum with the given
|
||||
* learning rate scheduler parameters. Optionally specify the initial learning rate
|
||||
* that should be used with this learning rate scheduler and training session.
|
||||
*
|
||||
* \param[in] sess The training session that should use the linear learning rate scheduler.
|
||||
* \param[in] lr_scheduler_parameters Learning rate scheduler parameters struct.
|
||||
* \param[in] initial_lr The initial learning rate to be used by the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(RegisterLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ void* lr_scheduler_parameters,
|
||||
_In_ enum OrtLRSchedulerType lr_scheduler_type, _In_opt_ const float* initial_lr);
|
||||
|
||||
/** \brief Update the learning rate based on the registered learing rate scheduler.
|
||||
*
|
||||
* Takes a scheduler step that updates the learning rate that is being used by the training session.
|
||||
* This function should typically be called before invoking the optimizer step for each round,
|
||||
* or as determined necessary to update the learning rate being used by the training session.
|
||||
* Please note that a valid predefined learning rate scheduler must be first registered to invoke this
|
||||
* function.
|
||||
*
|
||||
* \param[in] sess The training session that has the registered learning rate scheduler.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
|
||||
|
||||
/** \brief Frees up the memory used up by the training session.
|
||||
*
|
||||
* This function frees up any memory that was allocated in the training session. The training
|
||||
|
|
|
|||
|
|
@ -68,16 +68,22 @@ struct Optimizer {
|
|||
|
||||
Status LoadStateDict(const OptimizerCheckpointState& optimizer_checkpoint_states);
|
||||
|
||||
inline Status SetLearningRate(float lr) {
|
||||
optimizer_state_.learning_rate = lr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status SetInitialLearningRate(float initial_lr) {
|
||||
optimizer_state_.initial_lr = initial_lr;
|
||||
optimizer_state_.learning_rate = initial_lr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t GetStep() const {
|
||||
return optimizer_state_.step;
|
||||
}
|
||||
|
||||
Status SetLearningRate(float lr) {
|
||||
optimizer_state_.learning_rate = lr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Generates optimizer momentum states for applicable optimizer types
|
||||
Status GenerateMomentumNamedStates();
|
||||
// Constructs the ortvalue inputs to be fed to the graph
|
||||
|
|
|
|||
|
|
@ -25,8 +25,15 @@ ORT_API_STATUS_IMPL(EvalStep, _In_ const OrtTrainingSession* session, _In_opt_ c
|
|||
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
|
||||
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
ORT_API_STATUS_IMPL(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
|
||||
|
||||
ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options);
|
||||
|
||||
ORT_API_STATUS_IMPL(RegisterLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ void* lr_scheduler_parameters,
|
||||
_In_ enum OrtLRSchedulerType lr_scheduler_type, _In_opt_ const float* initial_lr);
|
||||
|
||||
ORT_API_STATUS_IMPL(SchedulerStep, _Inout_ OrtTrainingSession* sess);
|
||||
|
||||
ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
||||
_Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/common/common.h"
|
||||
#include "module.h"
|
||||
#include "optimizer.h"
|
||||
#include "lr_scheduler.h"
|
||||
#include "checkpoint.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -30,6 +31,10 @@ class TrainingSession {
|
|||
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters,
|
||||
const ModelIdentifiers& model_identifiers);
|
||||
|
||||
Status RegisterScheduler(const std::function<
|
||||
std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>& get_scheduler,
|
||||
std::optional<float> initial_lr);
|
||||
|
||||
size_t GetTrainModeOutputCount() const noexcept;
|
||||
|
||||
size_t GetEvalModeOutputCount() const noexcept;
|
||||
|
|
@ -46,6 +51,10 @@ class TrainingSession {
|
|||
|
||||
Status OptimizerStep(const RunOptions& run_options);
|
||||
|
||||
Status SetLearningRate(float learning_rate) noexcept;
|
||||
|
||||
Status SchedulerStep() noexcept;
|
||||
|
||||
Status CreateCheckpointState(CheckpointState& chkpt_state, bool save_optimizer_state) const;
|
||||
|
||||
private:
|
||||
|
|
@ -53,7 +62,8 @@ class TrainingSession {
|
|||
|
||||
const std::unordered_map<std::string, std::shared_ptr<Parameter>> named_parameters_;
|
||||
std::unique_ptr<Module> module_;
|
||||
std::unique_ptr<Optimizer> optimizer_;
|
||||
std::shared_ptr<Optimizer> optimizer_;
|
||||
std::unique_ptr<LRSchedulerBase> scheduler_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -184,6 +184,17 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::EvalStep, _In_ const OrtTrainingSession* se
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::SetLearningRate, _Inout_ OrtTrainingSession* sess,
|
||||
_In_ float learning_rate) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(session->SetLearningRate(learning_rate));
|
||||
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options) {
|
||||
API_IMPL_BEGIN
|
||||
|
|
@ -199,6 +210,50 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession*
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::RegisterLRScheduler, _Inout_ OrtTrainingSession* sess,
|
||||
_In_ void* lr_scheduler_parameters, _In_ enum OrtLRSchedulerType lr_scheduler_type,
|
||||
_In_opt_ const float* initial_lr) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
OrtStatus* status = nullptr;
|
||||
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
|
||||
if (!lr_scheduler_parameters) {
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "The provided learning rate scheduler parameters is a nullptr.");
|
||||
}
|
||||
|
||||
switch (lr_scheduler_type) {
|
||||
case OrtLRSchedulerType::LinearLRScheduler: {
|
||||
auto parameters = reinterpret_cast<OrtLinearLRSchedulerParameters*>(lr_scheduler_parameters);
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(
|
||||
session->RegisterScheduler([¶meters](auto optimizer) {
|
||||
return std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
|
||||
optimizer, parameters->warmup_step_count, parameters->total_step_count);
|
||||
},
|
||||
initial_lr ? std::optional<float>(*initial_lr) : std::nullopt));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
status = OrtApis::CreateStatus(ORT_FAIL, "Could not decipher the OrtLRSchedulerType from the given argument.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::SchedulerStep, _Inout_ OrtTrainingSession* sess) {
|
||||
API_IMPL_BEGIN
|
||||
|
||||
auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
|
||||
ORT_API_RETURN_IF_STATUS_NOT_OK(session->SchedulerStep());
|
||||
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
||||
_Outptr_ OrtCheckpointState** checkpoint_state) {
|
||||
API_IMPL_BEGIN
|
||||
|
|
@ -240,7 +295,10 @@ static constexpr OrtTrainingApi ort_training_api = {
|
|||
&OrtTrainingApis::ResetGrad,
|
||||
&OrtTrainingApis::TrainStep,
|
||||
&OrtTrainingApis::EvalStep,
|
||||
&OrtTrainingApis::SetLearningRate,
|
||||
&OrtTrainingApis::OptimizerStep,
|
||||
&OrtTrainingApis::RegisterLRScheduler,
|
||||
&OrtTrainingApis::SchedulerStep,
|
||||
&OrtTrainingApis::ReleaseTrainingSession,
|
||||
&OrtTrainingApis::ReleaseCheckpointState,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -21,6 +21,19 @@ TrainingSession::TrainingSession(const Environment& session_env,
|
|||
session_options, session_env, providers)
|
||||
: std::unique_ptr<Optimizer>()} {}
|
||||
|
||||
Status TrainingSession::RegisterScheduler(
|
||||
const std::function<std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>& get_scheduler,
|
||||
std::optional<float> initial_lr) {
|
||||
scheduler_ = std::move(get_scheduler(optimizer_));
|
||||
ORT_RETURN_IF_NOT(scheduler_, "The provided instance of the learning rate scheduler is a nullptr.");
|
||||
|
||||
if (initial_lr.has_value()) {
|
||||
ORT_RETURN_IF_ERROR(optimizer_->SetInitialLearningRate(initial_lr.value()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
size_t TrainingSession::GetTrainModeOutputCount() const noexcept {
|
||||
return module_->GetTrainModeOutputCount();
|
||||
}
|
||||
|
|
@ -58,6 +71,17 @@ Status TrainingSession::CreateCheckpointState(CheckpointState& chkpt_state, bool
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TrainingSession::SetLearningRate(float learning_rate) noexcept {
|
||||
ORT_RETURN_IF_ERROR(optimizer_->SetLearningRate(learning_rate));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TrainingSession::SchedulerStep() noexcept {
|
||||
ORT_RETURN_IF_NOT(scheduler_, "No learning rate schedler was registered. Please register a valid learning rate scheduler");
|
||||
return scheduler_->Step();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue