Add Learning Rate Scheduler C API (#11957)

This commit is contained in:
Baiju Meswani 2022-08-15 09:10:25 -07:00 committed by GitHub
parent 73da3f3705
commit f5e3517c39
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 177 additions and 11 deletions

View file

@ -256,10 +256,11 @@ int RunTraining(const TestRunnerParameters& params) {
onnxruntime::training::test::training_api::SyntheticDataLoader data_loader;
InitSyntheticDataLoader(data_loader, params, num_of_batches_per_epoch);
// TODO(baiju): Add C API for LRScheduler
// int64_t total_step_count = params.num_train_epochs * num_of_batches_per_epoch;
// int64_t warmup_step_count = total_step_count / 3;
// Ort::OrtLinearLRScheduler scheduler = Ort::OrtLinearLRScheduler(optimizer, warmup_step_count, total_step_count);
auto lr_scheduler_parameters = std::make_unique<OrtLinearLRSchedulerParameters>().get();
lr_scheduler_parameters->total_step_count = params.num_train_epochs * num_of_batches_per_epoch;
lr_scheduler_parameters->warmup_step_count = lr_scheduler_parameters->total_step_count / 3;
ORT_RETURN_ON_ERROR(g_ort_training_api->RegisterLRScheduler(
session, reinterpret_cast<void*>(lr_scheduler_parameters), OrtLRSchedulerType::LinearLRScheduler, nullptr));
std::cout << "Initialization completed. Now starting training loop." << std::endl;
const int64_t stabilized_perf_start_step = 0;
@ -309,7 +310,7 @@ int RunTraining(const TestRunnerParameters& params) {
#endif
// Update learning rate.
// EnforceCheck(scheduler.Step(), "Failed during shceduler.Step()");
ORT_RETURN_ON_ERROR(g_ort_training_api->SchedulerStep(session));
#if defined(USE_CUDA) && defined(ENABLE_NVTX_PROFILE)
onnxruntime::profile::NvtxRangeCreator resetgrad_range(

View file

@ -9,6 +9,15 @@
ORT_RUNTIME_CLASS(TrainingSession); /// Type that enables performing training for the given user models.
ORT_RUNTIME_CLASS(CheckpointState); /// Type that holds the training states for the training session.
/// Selects which predefined learning rate scheduler RegisterLRScheduler creates.
typedef enum OrtLRSchedulerType {
  /// Linear scheduler; its parameters are passed as an OrtLinearLRSchedulerParameters struct.
  LinearLRScheduler = 0
} OrtLRSchedulerType;
/// Parameters for the scheduler selected by OrtLRSchedulerType::LinearLRScheduler.
/// RegisterLRScheduler forwards both fields to the LinearLRScheduler it constructs.
typedef struct OrtLinearLRSchedulerParameters {
int64_t warmup_step_count;  ///< Warmup step count given to the scheduler (presumably the number of warmup steps - confirm).
int64_t total_step_count;  ///< Total step count given to the scheduler (presumably the length of the whole schedule - confirm).
} OrtLinearLRSchedulerParameters;
struct OrtTrainingApi {
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
*
@ -144,6 +153,26 @@ struct OrtTrainingApi {
size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
/** \brief Sets the learning rate for this training session.
*
* This function allows users to set the learning rate for the training session. The current
* learning rate is maintained by the training session and can be overwritten by invoking
* this function with the desired learning rate. This function should not be used when a valid
* learning rate scheduler is registered. It should be used either to set the learning rate
* derived from a custom learning rate scheduler or to set the learning rate constant to be used
* throughout the training session.
* Please note that this function does not set the initial learning rate that may be needed
* by the predefined learning rate schedulers. To set the initial learning rate for learning
* rate schedulers, please look at the function `RegisterLRScheduler`.
*
* \param[in] sess The training session on which the learning rate needs to be set.
* \param[in] learning_rate Desired learning rate to set.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
*
* This function performs the weight update step that updates the trainable parameters such that they
@ -161,6 +190,37 @@ struct OrtTrainingApi {
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options);
/** \brief Registers the use of the given learning rate scheduler for the training session.
*
* Register a learning rate scheduler identified by the given enum with the given
* learning rate scheduler parameters. Optionally specify the initial learning rate
* that should be used with this learning rate scheduler and training session.
*
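* \param[in] sess The training session that should use the given learning rate scheduler.
* \param[in] lr_scheduler_parameters Pointer to the parameters struct that matches lr_scheduler_type
*            (OrtLinearLRSchedulerParameters for LinearLRScheduler). Must not be null.
* \param[in] lr_scheduler_type Enum value identifying which learning rate scheduler to register.
* \param[in] initial_lr Optional initial learning rate to be used by the training session.
*            Pass a null pointer to register the scheduler without setting an initial learning rate.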
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(RegisterLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ void* lr_scheduler_parameters,
_In_ enum OrtLRSchedulerType lr_scheduler_type, _In_opt_ const float* initial_lr);
/** \brief Update the learning rate based on the registered learning rate scheduler.
*
* Takes a scheduler step that updates the learning rate that is being used by the training session.
* This function should typically be called before invoking the optimizer step for each round,
* or as determined necessary to update the learning rate being used by the training session.
* Please note that a valid predefined learning rate scheduler must be first registered to invoke this
* function.
*
* \param[in] sess The training session that has the registered learning rate scheduler.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
/** \brief Frees up the memory used up by the training session.
*
* This function frees up any memory that was allocated in the training session. The training

View file

@ -68,16 +68,22 @@ struct Optimizer {
Status LoadStateDict(const OptimizerCheckpointState& optimizer_checkpoint_states);
// Overwrites the current learning rate stored in the optimizer state.
// Always returns Status::OK().
Status SetLearningRate(float lr) {
  optimizer_state_.learning_rate = lr;
  return Status::OK();
}
// Stores initial_lr as the optimizer's initial learning rate and also makes it the
// current learning rate. Always returns Status::OK().
Status SetInitialLearningRate(float initial_lr) {
  optimizer_state_.learning_rate = initial_lr;
  optimizer_state_.initial_lr = initial_lr;
  return Status::OK();
}
private:
// Read-only accessor for the step counter kept in the optimizer state.
int64_t GetStep() const { return optimizer_state_.step; }
Status SetLearningRate(float lr) {
optimizer_state_.learning_rate = lr;
return Status::OK();
}
// Generates optimizer momentum states for applicable optimizer types
Status GenerateMomentumNamedStates();
// Constructs the ortvalue inputs to be fed to the graph

View file

@ -25,8 +25,15 @@ ORT_API_STATUS_IMPL(EvalStep, _In_ const OrtTrainingSession* session, _In_opt_ c
size_t inputs_len, _In_reads_(input_len) const OrtValue* const* inputs,
size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
// Sets the learning rate on the given training session.
ORT_API_STATUS_IMPL(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
ORT_API_STATUS_IMPL(OptimizerStep, _Inout_ OrtTrainingSession* session, _In_opt_ const OrtRunOptions* run_options);
// Registers the learning rate scheduler selected by lr_scheduler_type, configured from
// lr_scheduler_parameters, with the training session. initial_lr is optional and may be null.
ORT_API_STATUS_IMPL(RegisterLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ void* lr_scheduler_parameters,
_In_ enum OrtLRSchedulerType lr_scheduler_type, _In_opt_ const float* initial_lr);
// Invokes a step of the learning rate scheduler registered on the training session.
ORT_API_STATUS_IMPL(SchedulerStep, _Inout_ OrtTrainingSession* sess);
ORT_API_STATUS_IMPL(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
_Outptr_ OrtCheckpointState** checkpoint_state);

View file

@ -5,6 +5,7 @@
#include "core/common/common.h"
#include "module.h"
#include "optimizer.h"
#include "lr_scheduler.h"
#include "checkpoint.h"
namespace onnxruntime {
@ -30,6 +31,10 @@ class TrainingSession {
const std::unordered_map<std::string, std::shared_ptr<Parameter>>& parameters,
const ModelIdentifiers& model_identifiers);
// Creates a learning rate scheduler by calling get_scheduler with this session's optimizer and
// keeps the result as the session's scheduler. When initial_lr holds a value, it is also set as
// the optimizer's initial learning rate. Returns an error status when get_scheduler produces a
// null scheduler.
Status RegisterScheduler(const std::function<
std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>& get_scheduler,
std::optional<float> initial_lr);
size_t GetTrainModeOutputCount() const noexcept;
size_t GetEvalModeOutputCount() const noexcept;
@ -46,6 +51,10 @@ class TrainingSession {
Status OptimizerStep(const RunOptions& run_options);
// Sets the learning rate on the session's optimizer.
Status SetLearningRate(float learning_rate) noexcept;
// Calls Step() on the registered learning rate scheduler; returns an error status when no
// scheduler has been registered.
Status SchedulerStep() noexcept;
Status CreateCheckpointState(CheckpointState& chkpt_state, bool save_optimizer_state) const;
private:
@ -53,7 +62,8 @@ class TrainingSession {
const std::unordered_map<std::string, std::shared_ptr<Parameter>> named_parameters_;
std::unique_ptr<Module> module_;
std::unique_ptr<Optimizer> optimizer_;
// Shared because RegisterScheduler hands this same optimizer to the scheduler factory.
std::shared_ptr<Optimizer> optimizer_;
// Learning rate scheduler set by RegisterScheduler; null until one has been registered.
std::unique_ptr<LRSchedulerBase> scheduler_;
};
} // namespace api
} // namespace training

View file

@ -184,6 +184,17 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::EvalStep, _In_ const OrtTrainingSession* se
API_IMPL_END
}
// C API entry point: sets the learning rate on the training session behind the opaque handle.
// Returns nullptr on success, otherwise an OrtStatus describing the failure.
ORT_API_STATUS_IMPL(OrtTrainingApis::SetLearningRate, _Inout_ OrtTrainingSession* sess,
                    _In_ float learning_rate) {
  API_IMPL_BEGIN
  // The opaque C handle is reinterpreted as the C++ training session object.
  auto* training_session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
  ORT_API_RETURN_IF_STATUS_NOT_OK(training_session->SetLearningRate(learning_rate));

  return nullptr;
  API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options) {
API_IMPL_BEGIN
@ -199,6 +210,50 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::OptimizerStep, _Inout_ OrtTrainingSession*
API_IMPL_END
}
// C API entry point: registers a predefined learning rate scheduler on the training session.
//
// sess                    opaque handle of the training session.
// lr_scheduler_parameters pointer to the parameters struct matching lr_scheduler_type
//                         (OrtLinearLRSchedulerParameters for LinearLRScheduler); must not be null.
// lr_scheduler_type       selects which scheduler to construct.
// initial_lr              optional initial learning rate; null means "not provided".
//
// Returns nullptr on success. Returns an ORT_FAIL status when the parameters pointer is null or the
// scheduler type is not recognized, and propagates any failure reported by the session.
ORT_API_STATUS_IMPL(OrtTrainingApis::RegisterLRScheduler, _Inout_ OrtTrainingSession* sess,
                    _In_ void* lr_scheduler_parameters, _In_ enum OrtLRSchedulerType lr_scheduler_type,
                    _In_opt_ const float* initial_lr) {
  API_IMPL_BEGIN

  OrtStatus* status = nullptr;

  // The opaque C handle is reinterpreted as the C++ training session object.
  auto session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
  if (!lr_scheduler_parameters) {
    return OrtApis::CreateStatus(ORT_FAIL, "The provided learning rate scheduler parameters is a nullptr.");
  }

  switch (lr_scheduler_type) {
    case OrtLRSchedulerType::LinearLRScheduler: {
      // void* -> object pointer is a static_cast; the parameters are only read here.
      const auto* parameters = static_cast<const OrtLinearLRSchedulerParameters*>(lr_scheduler_parameters);

      // Copy the values into the closure so the factory does not depend on the lifetime of any
      // local variable or of the caller's struct.
      const auto warmup_step_count = parameters->warmup_step_count;
      const auto total_step_count = parameters->total_step_count;

      ORT_API_RETURN_IF_STATUS_NOT_OK(
          session->RegisterScheduler(
              [warmup_step_count, total_step_count](auto optimizer) {
                return std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
                    optimizer, warmup_step_count, total_step_count);
              },
              initial_lr ? std::optional<float>(*initial_lr) : std::nullopt));
      break;
    }
    default: {
      status = OrtApis::CreateStatus(ORT_FAIL, "Could not decipher the OrtLRSchedulerType from the given argument.");
      break;
    }
  }

  return status;
  API_IMPL_END
}
// C API entry point: takes a step of the learning rate scheduler registered on the training
// session behind the opaque handle. Returns nullptr on success, otherwise an OrtStatus
// describing the failure.
ORT_API_STATUS_IMPL(OrtTrainingApis::SchedulerStep, _Inout_ OrtTrainingSession* sess) {
  API_IMPL_BEGIN
  // The opaque C handle is reinterpreted as the C++ training session object.
  auto* training_session = reinterpret_cast<onnxruntime::training::api::TrainingSession*>(sess);
  ORT_API_RETURN_IF_STATUS_NOT_OK(training_session->SchedulerStep());

  return nullptr;
  API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
_Outptr_ OrtCheckpointState** checkpoint_state) {
API_IMPL_BEGIN
@ -240,7 +295,10 @@ static constexpr OrtTrainingApi ort_training_api = {
&OrtTrainingApis::ResetGrad,
&OrtTrainingApis::TrainStep,
&OrtTrainingApis::EvalStep,
&OrtTrainingApis::SetLearningRate,
&OrtTrainingApis::OptimizerStep,
&OrtTrainingApis::RegisterLRScheduler,
&OrtTrainingApis::SchedulerStep,
&OrtTrainingApis::ReleaseTrainingSession,
&OrtTrainingApis::ReleaseCheckpointState,
};

View file

@ -21,6 +21,19 @@ TrainingSession::TrainingSession(const Environment& session_env,
session_options, session_env, providers)
: std::unique_ptr<Optimizer>()} {}
// Registers a learning rate scheduler for this training session.
//
// get_scheduler  factory that receives the session's optimizer and returns the scheduler to use.
// initial_lr     when it holds a value, it is set as the optimizer's initial learning rate.
//
// Returns an error status when the session has no optimizer or when the factory returns a null
// scheduler. In both cases a previously registered scheduler is left in place.
Status TrainingSession::RegisterScheduler(
    const std::function<std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>& get_scheduler,
    std::optional<float> initial_lr) {
  // The session may have been constructed without an optimizer; a scheduler cannot work without one.
  ORT_RETURN_IF_NOT(optimizer_, "No optimizer is available in the training session. "
                                "A learning rate scheduler cannot be registered without an optimizer.");

  // Validate the new scheduler before replacing the current one.
  auto new_scheduler = get_scheduler(optimizer_);
  ORT_RETURN_IF_NOT(new_scheduler, "The provided instance of the learning rate scheduler is a nullptr.");
  scheduler_ = std::move(new_scheduler);

  if (initial_lr.has_value()) {
    ORT_RETURN_IF_ERROR(optimizer_->SetInitialLearningRate(initial_lr.value()));
  }

  return Status::OK();
}
size_t TrainingSession::GetTrainModeOutputCount() const noexcept {
return module_->GetTrainModeOutputCount();
}
@ -58,6 +71,17 @@ Status TrainingSession::CreateCheckpointState(CheckpointState& chkpt_state, bool
return Status::OK();
}
// Sets the learning rate on the session's optimizer.
// Returns an error status when the session has no optimizer, otherwise the optimizer's status.
Status TrainingSession::SetLearningRate(float learning_rate) noexcept {
  // The session may have been constructed without an optimizer; do not dereference a null pointer.
  ORT_RETURN_IF_NOT(optimizer_, "No optimizer is available in the training session. "
                                "The learning rate cannot be set without an optimizer.");
  ORT_RETURN_IF_ERROR(optimizer_->SetLearningRate(learning_rate));
  return Status::OK();
}
// Calls Step() on the registered learning rate scheduler.
// Returns an error status when no scheduler has been registered, otherwise the status
// reported by the scheduler.
Status TrainingSession::SchedulerStep() noexcept {
  ORT_RETURN_IF_NOT(scheduler_, "No learning rate scheduler was registered. Please register a valid learning rate scheduler");
  return scheduler_->Step();
}
} // namespace api
} // namespace training
} // namespace onnxruntime