expose lr scheduler python bindings for on device training. (#13882)

### Description
Exposing LR Scheduler python bindings for on device training.

Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
This commit is contained in:
Adam Louly 2022-12-22 20:44:04 -06:00 committed by GitHub
parent b999022b03
commit e49f358686
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 123 additions and 27 deletions

View file

@ -15,24 +15,23 @@ namespace py = pybind11;
using namespace onnxruntime::logging;
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider>>;
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions>>;
class ORTTrainingPythonEnv{
public:
class ORTTrainingPythonEnv {
public:
ORTTrainingPythonEnv();
Environment& GetORTEnv();
std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
size_t hash);
size_t hash);
void AddExecutionProvider(const std::string& provider_type,
size_t hash,
std::unique_ptr<IExecutionProvider> execution_provider);
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
const std::string& provider_lib_path,
const ProviderOptions& default_options);
@ -42,7 +41,7 @@ public:
void ClearExecutionProviderInstances();
private:
private:
std::string GetExecutionProviderMapKey(const std::string& provider_type,
size_t hash);
@ -51,5 +50,5 @@ private:
std::vector<std::string> available_training_eps_;
};
}
}
} // namespace python
} // namespace onnxruntime

View file

@ -36,7 +36,7 @@
#ifdef ENABLE_TRAINING_ON_DEVICE
#include "orttraining/training_api/include/checkpoint.h"
#include "core/providers/provider_factory_creators.h"
#include "orttraining/training_api/include/lr_scheduler.h"
#endif
@ -164,6 +164,20 @@ struct TrainingConfigurationResult {
optional<std::string> loss_scale_input_name;
};
#ifdef ENABLE_TRAINING_ON_DEVICE
// Thin wrapper over the internal C++ training Optimizer.
//
// The optimizer is held through a std::shared_ptr so that its ownership can be shared:
// the LinearLRScheduler binding is constructed from this same pointer.
struct PyOptimizer {
  // optimizer_model_uri: forwarded as the first argument of the Optimizer constructor.
  // model:               module whose NamedParameters() are handed to the optimizer. It is
  //                      dereferenced unconditionally, so it must not be null.
  // provider:            execution providers forwarded to the Optimizer constructor.
  PyOptimizer(const std::string& optimizer_model_uri,
              onnxruntime::training::api::Module* model,
              std::vector<std::shared_ptr<IExecutionProvider>> provider)
      // make_shared builds the object and its control block in a single allocation instead of
      // creating a unique_ptr that is immediately converted to a shared_ptr.
      : optimizer_(std::make_shared<onnxruntime::training::api::Optimizer>(
            optimizer_model_uri, model->NamedParameters(), onnxruntime::SessionOptions(),
            GetTrainingORTEnv(), provider)) {
  }

  std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
};
#endif
struct PyGradientGraphBuilder {
std::unique_ptr<GradientGraphBuilder> builder;
std::shared_ptr<Model> model;
@ -917,7 +931,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
return state;
}));
py::class_<onnxruntime::training::api::Optimizer>
py::class_<PyOptimizer>
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
training_optimizer.def(py::init([](
const std::string optimizer_model_uri,
@ -925,21 +939,34 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
OrtDevice device) {
onnxruntime::SessionOptions session_option;
std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
return std::make_unique<onnxruntime::training::api::Optimizer>(
optimizer_model_uri,
model->NamedParameters(), session_option,
GetTrainingORTEnv(), provider);
}))
.def("set_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer, float lr) -> void {
ORT_THROW_IF_ERROR(optimizer->SetLearningRate(lr));
})
.def("get_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer) -> float {
return optimizer->GetLearningRate();
})
.def("optimizer_step", [](onnxruntime::training::api::Optimizer* optimizer) -> void {
ORT_THROW_IF_ERROR(optimizer->Step());
});
return std::make_unique<PyOptimizer>(
optimizer_model_uri,
model, provider);
}))
.def("optimizer_step", [](PyOptimizer* optimizer) -> void {
ORT_THROW_IF_ERROR(optimizer->optimizer_->Step());
})
.def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void {
ORT_THROW_IF_ERROR(optimizer->optimizer_->SetLearningRate(lr));
})
.def("get_learning_rate", [](PyOptimizer* optimizer) -> float {
return optimizer->optimizer_->GetLearningRate();
});
// Python-visible "LinearLRScheduler", bound directly onto the C++ training API class.
py::class_<onnxruntime::training::api::LinearLRScheduler> lr_scheduler(
    m, "LinearLRScheduler", R"pbdoc(Learning Rate Scheduler.)pbdoc");

// Constructor binding. The step counts are taken positionally in the order
// (total_step_count, warmup_step_count) and forwarded by name to the C++ scheduler.
lr_scheduler.def(py::init([](PyOptimizer* py_optimizer,
                             int64_t total_step_count,
                             int64_t warmup_step_count,
                             float initial_lr) {
  // Set the initial learning rate on the wrapped optimizer before the scheduler is built on it.
  auto& shared_optimizer = py_optimizer->optimizer_;
  ORT_THROW_IF_ERROR(shared_optimizer->SetInitialLearningRate(initial_lr));

  return std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
      shared_optimizer, warmup_step_count, total_step_count);
}));

// Advances the scheduler by one step; a failing status is surfaced through ORT_THROW_IF_ERROR.
lr_scheduler.def("scheduler_step", [](onnxruntime::training::api::LinearLRScheduler* self) -> void {
  ORT_THROW_IF_ERROR(self->Step());
});
m.def("save_checkpoint",
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
const std::vector<py::bytes>& non_trainable_tensor_protos_pybytes,

View file

@ -1,3 +1,4 @@
from .checkpoint_state import CheckpointState
from .lr_scheduler import LinearLRScheduler
from .module import Module
from .optimizer import Optimizer

View file

@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# lr_scheduler.py
from onnxruntime.capi import _pybind_state as C
class LinearLRScheduler:
    """
    Linearly updates the learning rate in the optimizer

    The linear learning rate scheduler decays the learning rate by a linearly updated
    multiplicative factor from the initial learning rate set on the training session to 0. The decay
    is performed after the initial warm up phase where the learning rate is linearly incremented
    from 0 to the initial learning rate provided.

    Args:
        optimizer (:obj:`training_api.Optimizer`): User's onnxruntime training Optimizer
        warmup_step_count (int): The number of steps in the warm up phase.
        total_step_count (int): The total number of training steps.
        initial_lr (float): The initial learning rate.
    """

    def __init__(self, optimizer, warmup_step_count, total_step_count, initial_lr) -> None:
        # The pybind constructor takes its step counts positionally as
        # (total_step_count, warmup_step_count), i.e. in the opposite order to this
        # Python signature, so they are deliberately swapped in the call below.
        self._scheduler = C.LinearLRScheduler(optimizer._optimizer, total_step_count, warmup_step_count, initial_lr)

    def step(self):
        """
        The step method of the LinearLRScheduler class is used to update the learning rate of the optimizer according
        to the scheduler's strategy.
        This method should be called at each step of training to ensure that the learning rate is properly adjusted.
        """
        self._scheduler.scheduler_step()

View file

@ -7,7 +7,7 @@ import torch
from orttraining_test_onnxblock import _get_models
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer
class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel):
@ -168,6 +168,41 @@ def test_get_and_set_lr():
assert lr != new_lr
def test_scheduler_step():
    """Stepping a LinearLRScheduler must change the learning rate held by the optimizer."""
    # Build the training artifacts used by the test.
    simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()

    # One random batch: 64 samples with 784 features, and integer labels below 10.
    batch_features = torch.randn(64, 784).numpy()
    batch_labels = torch.randint(high=10, size=(64,), dtype=torch.int32).numpy()
    forward_inputs = [batch_features, batch_labels]

    with tempfile.TemporaryDirectory() as temp_dir:
        # Persist the models and checkpoint so they can be loaded back from disk.
        saved_paths = _get_test_models_path(temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model)
        checkpoint_file_path, model_file_path, optimizer_file_path = saved_paths

        # Load the checkpoint, then create the Module and its Optimizer.
        state = CheckpointState(checkpoint_file_path)
        model = Module(model_file_path, state)
        optimizer = Optimizer(optimizer_file_path, model)

        # Scheduler arguments: warmup_step_count=1, total_step_count=2, initial_lr=0.2.
        scheduler = LinearLRScheduler(optimizer, 1, 2, 0.2)

        # Right after the scheduler is created the learning rate is expected to be 0.
        lr_before = optimizer.get_learning_rate()
        assert np.allclose(lr_before, 0.0)

        # Run a single training step followed by a single scheduler step.
        model.train()
        model(forward_inputs)
        optimizer.step()
        scheduler.step()

        # The scheduler step must have moved the learning rate.
        lr_after = optimizer.get_learning_rate()
        assert lr_after != lr_before
def test_training_module_checkpoint():
# Initialize Models
simple_model, onnx_model, _, _, _ = _create_training_models()