From f5e3517c39b624bd027fdd9f5d3146d535bd74ac Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 15 Aug 2022 09:10:25 -0700 Subject: [PATCH] Add Learning Rate Scheduler C API (#11957) --- .../test/training_api/trainer/trainer.cc | 11 ++-- .../include/onnxruntime_training_c_api.h | 60 +++++++++++++++++++ .../training_api/include/optimizer.h | 16 +++-- .../training_api/include/ort_training_apis.h | 7 +++ .../training_api/include/training_session.h | 12 +++- .../onnxruntime_training_c_api.cc | 58 ++++++++++++++++++ .../training_api/training_session.cc | 24 ++++++++ 7 files changed, 177 insertions(+), 11 deletions(-) diff --git a/orttraining/orttraining/test/training_api/trainer/trainer.cc b/orttraining/orttraining/test/training_api/trainer/trainer.cc index ca54c89795..ce5f572726 100644 --- a/orttraining/orttraining/test/training_api/trainer/trainer.cc +++ b/orttraining/orttraining/test/training_api/trainer/trainer.cc @@ -256,10 +256,11 @@ int RunTraining(const TestRunnerParameters& params) { onnxruntime::training::test::training_api::SyntheticDataLoader data_loader; InitSyntheticDataLoader(data_loader, params, num_of_batches_per_epoch); - // TODO(baiju): Add C API for LRScheduler - // int64_t total_step_count = params.num_train_epochs * num_of_batches_per_epoch; - // int64_t warmup_step_count = total_step_count / 3; - // Ort::OrtLinearLRScheduler scheduler = Ort::OrtLinearLRScheduler(optimizer, warmup_step_count, total_step_count); + auto lr_scheduler_parameters = std::make_unique().get(); + lr_scheduler_parameters->total_step_count = params.num_train_epochs * num_of_batches_per_epoch; + lr_scheduler_parameters->warmup_step_count = lr_scheduler_parameters->total_step_count / 3; + ORT_RETURN_ON_ERROR(g_ort_training_api->RegisterLRScheduler( + session, reinterpret_cast(lr_scheduler_parameters), OrtLRSchedulerType::LinearLRScheduler, nullptr)); std::cout << "Initialization completed. Now starting training loop." << std::endl; const int64_t stabilized_perf_start_step = 0; @@ -309,7 +310,7 @@ int RunTraining(const TestRunnerParameters& params) { #endif // Update learning rate. - // EnforceCheck(scheduler.Step(), "Failed during shceduler.Step()"); + ORT_RETURN_ON_ERROR(g_ort_training_api->SchedulerStep(session)); #if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE) onnxruntime::profile::NvtxRangeCreator resetgrad_range( diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 9c57ded042..593c60ac04 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -9,6 +9,15 @@ ORT_RUNTIME_CLASS(TrainingSession); /// Type that enables performing training for the given user models. ORT_RUNTIME_CLASS(CheckpointState); /// Type that holds the training states for the training session. +typedef enum OrtLRSchedulerType { + LinearLRScheduler +} OrtLRSchedulerType; + +typedef struct OrtLinearLRSchedulerParameters { + int64_t warmup_step_count; + int64_t total_step_count; +} OrtLinearLRSchedulerParameters; + struct OrtTrainingApi { /** \brief Load a checkpoint state from directory on disk into checkpoint_state. * @@ -144,6 +153,26 @@ struct OrtTrainingApi { size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); + /** \brief Sets the learning rate for this training session. + * + * This function allows users to set the learning rate for the training session. The current + * learning rate is maintained by the training session and can be overwritten by invoking + * this function with the desired learning rate. This function should not be used when a valid + * learning rate scheduler is registered. It should be used either to set the learning rate + * derived from a custom learning rate scheduler or to set the learning rate constant to be used + * throughout the training session. + * Please note that this function does not set the initial learning rate that may be needed + * by the predefined learning rate schedulers. To set the initial learning rate for learning + * rate schedulers, please look at the function `RegisterLRScheduler`. + * + * \param[in] sess The training session on which the learning rate needs to be set. + * \param[in] learning_rate Desired learning rate to set. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate); + /** \brief Performs the weight updates for the trainable parameters using the optimizer model. * * This function performs the weight update step that updates the trainable parameters such that they @@ -161,6 +190,37 @@ struct OrtTrainingApi { ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options); + /** \brief Registers the use of the given learning rate scheduler for the training session. + * + * Register a learning rate scheduler identified by the given enum with the given + * learning rate scheduler parameters. Optionally specify the initial learning rate + * that should be used with this learning rate scheduler and training session. + * + * \param[in] sess The training session that should use the linear learning rate scheduler. + * \param[in] lr_scheduler_parameters Learning rate scheduler parameters struct. + * \param[in] initial_lr The initial learning rate to be used by the training session. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(RegisterLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ void* lr_scheduler_parameters, + _In_ enum OrtLRSchedulerType lr_scheduler_type, _In_opt_ const float* initial_lr); + + /** \brief Update the learning rate based on the registered learing rate scheduler. + * + * Takes a scheduler step that updates the learning rate that is being used by the training session. + * This function should typically be called before invoking the optimizer step for each round, + * or as determined necessary to update the learning rate being used by the training session. + * Please note that a valid predefined learning rate scheduler must be first registered to invoke this + * function. + * + * \param[in] sess The training session that has the registered learning rate scheduler. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess); + /** \brief Frees up the memory used up by the training session. * * This function frees up any memory that was allocated in the training session. The training diff --git a/orttraining/orttraining/training_api/include/optimizer.h b/orttraining/orttraining/training_api/include/optimizer.h index 6b26f88280..9efa836473 100644 --- a/orttraining/orttraining/training_api/include/optimizer.h +++ b/orttraining/orttraining/training_api/include/optimizer.h @@ -68,16 +68,22 @@ struct Optimizer { Status LoadStateDict(const OptimizerCheckpointState& optimizer_checkpoint_states); + inline Status SetLearningRate(float lr) { + optimizer_state_.learning_rate = lr; + return Status::OK(); + } + + inline Status SetInitialLearningRate(float initial_lr) { + optimizer_state_.initial_lr = initial_lr; + optimizer_state_.learning_rate = initial_lr; + return Status::OK(); + } + private: int64_t GetStep() const { return optimizer_state_.step; } - Status SetLearningRate(float lr) { - optimizer_state_.learning_rate = lr; - return Status::OK(); - } - // Generates optimizer momentum states for applicable optimizer types Status GenerateMomentumNamedStates(); // Constructs the ortvalue inputs to be fed to the graph diff --git a/orttraining/orttraining/training_api/include/ort_training_apis.h b/orttraining/orttraining/training_api/include/ort_training_apis.h index e3bd741999..40215bf356 100644 --- a/orttraining/orttraining/training_api/include/ort_training_apis.h +++ b/orttraining/orttraining/training_api/include/ort_training_apis.h @@ -25,8 +25,15 @@ ORT_API_STATUS_IMPL(EvalStep, _In_ const OrtTrainingSession* session, _In_opt_ c size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs, size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); +ORT_API_STATUS_IMPL(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate); + ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options); +ORT_API_STATUS_IMPL(RegisterLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ void* lr_scheduler_parameters, + _In_ enum OrtLRSchedulerType lr_scheduler_type, _In_opt_ const float* initial_lr); + +ORT_API_STATUS_IMPL(SchedulerStep, _Inout_ OrtTrainingSession* sess); + ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state); diff --git a/orttraining/orttraining/training_api/include/training_session.h b/orttraining/orttraining/training_api/include/training_session.h index eda91090a0..b1557988a9 100644 --- a/orttraining/orttraining/training_api/include/training_session.h +++ b/orttraining/orttraining/training_api/include/training_session.h @@ -5,6 +5,7 @@ #include "core/common/common.h" #include "module.h" #include "optimizer.h" +#include "lr_scheduler.h" #include "checkpoint.h" namespace onnxruntime { @@ -30,6 +31,10 @@ class TrainingSession { const std::unordered_map>& parameters, const ModelIdentifiers& model_identifiers); + Status RegisterScheduler(const std::function< + std::unique_ptr(std::shared_ptr)>& get_scheduler, + std::optional initial_lr); + size_t GetTrainModeOutputCount() const noexcept; size_t GetEvalModeOutputCount() const noexcept; @@ -46,6 +51,10 @@ class TrainingSession { Status OptimizerStep(const RunOptions& run_options); + Status SetLearningRate(float learning_rate) noexcept; + + Status SchedulerStep() noexcept; + Status CreateCheckpointState(CheckpointState& chkpt_state, bool save_optimizer_state) const; private: @@ -53,7 +62,8 @@ class TrainingSession { const std::unordered_map> named_parameters_; std::unique_ptr module_; - std::unique_ptr optimizer_; + std::shared_ptr optimizer_; + std::unique_ptr scheduler_; }; } // namespace api } // namespace training diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index f954980593..6ef00e47d9 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -184,6 +184,17 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::EvalStep, _In_ const OrtTrainingSession* se API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::SetLearningRate, _Inout_ OrtTrainingSession* sess, + _In_ float learning_rate) { + API_IMPL_BEGIN + + auto session = reinterpret_cast(sess); + ORT_API_RETURN_IF_STATUS_NOT_OK(session->SetLearningRate(learning_rate)); + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options) { API_IMPL_BEGIN @@ -199,6 +210,50 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::RegisterLRScheduler, _Inout_ OrtTrainingSession* sess, + _In_ void* lr_scheduler_parameters, _In_ enum OrtLRSchedulerType lr_scheduler_type, + _In_opt_ const float* initial_lr) { + API_IMPL_BEGIN + + OrtStatus* status = nullptr; + + auto session = reinterpret_cast(sess); + + if (!lr_scheduler_parameters) { + return OrtApis::CreateStatus(ORT_FAIL, "The provided learning rate scheduler parameters is a nullptr."); + } + + switch (lr_scheduler_type) { + case OrtLRSchedulerType::LinearLRScheduler: { + auto parameters = reinterpret_cast(lr_scheduler_parameters); + ORT_API_RETURN_IF_STATUS_NOT_OK( + session->RegisterScheduler([¶meters](auto optimizer) { + return std::make_unique( + optimizer, parameters->warmup_step_count, parameters->total_step_count); + }, + initial_lr ? std::optional(*initial_lr) : std::nullopt)); + break; + } + default: { + status = OrtApis::CreateStatus(ORT_FAIL, "Could not decipher the OrtLRSchedulerType from the given argument."); + break; + } + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::SchedulerStep, _Inout_ OrtTrainingSession* sess) { + API_IMPL_BEGIN + + auto session = reinterpret_cast(sess); + ORT_API_RETURN_IF_STATUS_NOT_OK(session->SchedulerStep()); + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN @@ -240,7 +295,10 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::ResetGrad, &OrtTrainingApis::TrainStep, &OrtTrainingApis::EvalStep, + &OrtTrainingApis::SetLearningRate, &OrtTrainingApis::OptimizerStep, + &OrtTrainingApis::RegisterLRScheduler, + &OrtTrainingApis::SchedulerStep, &OrtTrainingApis::ReleaseTrainingSession, &OrtTrainingApis::ReleaseCheckpointState, }; diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index fbd4e55d90..b3549aebb6 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -21,6 +21,19 @@ TrainingSession::TrainingSession(const Environment& session_env, session_options, session_env, providers) : std::unique_ptr()} {} +Status TrainingSession::RegisterScheduler( + const std::function(std::shared_ptr)>& get_scheduler, + std::optional initial_lr) { + scheduler_ = std::move(get_scheduler(optimizer_)); + ORT_RETURN_IF_NOT(scheduler_, "The provided instance of the learning rate scheduler is a nullptr."); + + if (initial_lr.has_value()) { + ORT_RETURN_IF_ERROR(optimizer_->SetInitialLearningRate(initial_lr.value())); + } + + return Status::OK(); +} + size_t TrainingSession::GetTrainModeOutputCount() const noexcept { return module_->GetTrainModeOutputCount(); } @@ -58,6 +71,17 @@ Status TrainingSession::CreateCheckpointState(CheckpointState& chkpt_state, bool return Status::OK(); } +Status TrainingSession::SetLearningRate(float learning_rate) noexcept { + ORT_RETURN_IF_ERROR(optimizer_->SetLearningRate(learning_rate)); + + return Status::OK(); +} + +Status TrainingSession::SchedulerStep() noexcept { + ORT_RETURN_IF_NOT(scheduler_, "No learning rate schedler was registered. Please register a valid learning rate scheduler"); + return scheduler_->Step(); +} + } // namespace api } // namespace training } // namespace onnxruntime