mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
expose lr scheduler python bindings for on device training. (#13882)
### Description Exposing LR Scheduler python bindings for on device training. Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
This commit is contained in:
parent
b999022b03
commit
e49f358686
5 changed files with 123 additions and 27 deletions
|
|
@ -15,24 +15,23 @@ namespace py = pybind11;
|
|||
|
||||
using namespace onnxruntime::logging;
|
||||
|
||||
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
|
||||
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
|
||||
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider>>;
|
||||
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions>>;
|
||||
|
||||
|
||||
class ORTTrainingPythonEnv{
|
||||
public:
|
||||
class ORTTrainingPythonEnv {
|
||||
public:
|
||||
ORTTrainingPythonEnv();
|
||||
|
||||
Environment& GetORTEnv();
|
||||
|
||||
std::shared_ptr<IExecutionProvider> GetExecutionProviderInstance(const std::string& provider_type,
|
||||
size_t hash);
|
||||
size_t hash);
|
||||
|
||||
void AddExecutionProvider(const std::string& provider_type,
|
||||
size_t hash,
|
||||
std::unique_ptr<IExecutionProvider> execution_provider);
|
||||
|
||||
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
|
||||
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
|
||||
const std::string& provider_lib_path,
|
||||
const ProviderOptions& default_options);
|
||||
|
||||
|
|
@ -42,7 +41,7 @@ public:
|
|||
|
||||
void ClearExecutionProviderInstances();
|
||||
|
||||
private:
|
||||
private:
|
||||
std::string GetExecutionProviderMapKey(const std::string& provider_type,
|
||||
size_t hash);
|
||||
|
||||
|
|
@ -51,5 +50,5 @@ private:
|
|||
std::vector<std::string> available_training_eps_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@
|
|||
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
#include "orttraining/training_api/include/checkpoint.h"
|
||||
#include "core/providers/provider_factory_creators.h"
|
||||
#include "orttraining/training_api/include/lr_scheduler.h"
|
||||
|
||||
#endif
|
||||
|
||||
|
|
@ -164,6 +164,20 @@ struct TrainingConfigurationResult {
|
|||
optional<std::string> loss_scale_input_name;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_TRAINING_ON_DEVICE
|
||||
// Thin wrapper over internal C++ Optimizer
|
||||
struct PyOptimizer {
|
||||
PyOptimizer(const std::string optimizer_model_uri,
|
||||
onnxruntime::training::api::Module* model, std::vector<std::shared_ptr<IExecutionProvider>> provider)
|
||||
: optimizer_(std::make_unique<onnxruntime::training::api::Optimizer>(optimizer_model_uri,
|
||||
model->NamedParameters(), onnxruntime::SessionOptions(),
|
||||
GetTrainingORTEnv(), provider)) {
|
||||
}
|
||||
|
||||
std::shared_ptr<onnxruntime::training::api::Optimizer> optimizer_;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct PyGradientGraphBuilder {
|
||||
std::unique_ptr<GradientGraphBuilder> builder;
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
@ -917,7 +931,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
return state;
|
||||
}));
|
||||
|
||||
py::class_<onnxruntime::training::api::Optimizer>
|
||||
py::class_<PyOptimizer>
|
||||
training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc");
|
||||
training_optimizer.def(py::init([](
|
||||
const std::string optimizer_model_uri,
|
||||
|
|
@ -925,21 +939,34 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
OrtDevice device) {
|
||||
onnxruntime::SessionOptions session_option;
|
||||
std::vector<std::shared_ptr<IExecutionProvider>> provider = GetExecutionProvidersForTrainingApis(device);
|
||||
return std::make_unique<onnxruntime::training::api::Optimizer>(
|
||||
optimizer_model_uri,
|
||||
model->NamedParameters(), session_option,
|
||||
GetTrainingORTEnv(), provider);
|
||||
}))
|
||||
.def("set_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer, float lr) -> void {
|
||||
ORT_THROW_IF_ERROR(optimizer->SetLearningRate(lr));
|
||||
})
|
||||
.def("get_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer) -> float {
|
||||
return optimizer->GetLearningRate();
|
||||
})
|
||||
.def("optimizer_step", [](onnxruntime::training::api::Optimizer* optimizer) -> void {
|
||||
ORT_THROW_IF_ERROR(optimizer->Step());
|
||||
});
|
||||
|
||||
return std::make_unique<PyOptimizer>(
|
||||
optimizer_model_uri,
|
||||
model, provider);
|
||||
}))
|
||||
.def("optimizer_step", [](PyOptimizer* optimizer) -> void {
|
||||
ORT_THROW_IF_ERROR(optimizer->optimizer_->Step());
|
||||
})
|
||||
.def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void {
|
||||
ORT_THROW_IF_ERROR(optimizer->optimizer_->SetLearningRate(lr));
|
||||
})
|
||||
.def("get_learning_rate", [](PyOptimizer* optimizer) -> float {
|
||||
return optimizer->optimizer_->GetLearningRate();
|
||||
});
|
||||
py::class_<onnxruntime::training::api::LinearLRScheduler>
|
||||
lr_scheduler(m, "LinearLRScheduler", R"pbdoc(Learning Rate Scheduler.)pbdoc");
|
||||
lr_scheduler.def(py::init([](PyOptimizer* optimizer,
|
||||
int64_t total_step_count,
|
||||
int64_t warmup_step_count,
|
||||
float initial_lr) {
|
||||
ORT_THROW_IF_ERROR(optimizer->optimizer_->SetInitialLearningRate(initial_lr));
|
||||
|
||||
return std::make_unique<onnxruntime::training::api::LinearLRScheduler>(
|
||||
optimizer->optimizer_, warmup_step_count, total_step_count);
|
||||
}))
|
||||
.def("scheduler_step", [](onnxruntime::training::api::LinearLRScheduler* scheduler) -> void {
|
||||
ORT_THROW_IF_ERROR(scheduler->Step());
|
||||
});
|
||||
m.def("save_checkpoint",
|
||||
[](const std::vector<py::bytes>& trainable_tensor_protos_pybytes,
|
||||
const std::vector<py::bytes>& non_trainable_tensor_protos_pybytes,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from .checkpoint_state import CheckpointState
|
||||
from .lr_scheduler import LinearLRScheduler
|
||||
from .module import Module
|
||||
from .optimizer import Optimizer
|
||||
|
|
|
|||
34
orttraining/orttraining/python/training/api/lr_scheduler.py
Normal file
34
orttraining/orttraining/python/training/api/lr_scheduler.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# lr_scheduler.py
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
|
||||
|
||||
class LinearLRScheduler:
    """
    Linearly updates the learning rate in the optimizer

    The linear learning rate scheduler decays the learning rate by a linearly updated
    multiplicative factor from the initial learning rate set on the training session to 0. The decay
    is performed after the initial warm up phase where the learning rate is linearly incremented
    from 0 to the initial learning rate provided.

    Args:
        optimizer (:obj:`training_api.Optimizer`): User's onnxruntime training Optimizer
        warmup_step_count (int): The number of steps in the warm up phase.
        total_step_count (int): The total number of training steps.
        initial_lr (float): The initial learning rate.
    """

    def __init__(self, optimizer, warmup_step_count, total_step_count, initial_lr) -> None:
        # The native LinearLRScheduler binding declares its positional arguments as
        # (optimizer, total_step_count, warmup_step_count, initial_lr), i.e. with the two step
        # counts in the opposite order from this constructor, so they are swapped here.
        self._scheduler = C.LinearLRScheduler(optimizer._optimizer, total_step_count, warmup_step_count, initial_lr)

    def step(self):
        """
        The step method of the LinearLRScheduler class is used to update the learning rate of the optimizer according
        to the scheduler's strategy.
        This method should be called at each step of training to ensure that the learning rate is properly adjusted.
        """
        self._scheduler.scheduler_step()
|
||||
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from orttraining_test_onnxblock import _get_models
|
||||
|
||||
import onnxruntime.training.onnxblock as onnxblock
|
||||
from onnxruntime.training.api import CheckpointState, Module, Optimizer
|
||||
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer
|
||||
|
||||
|
||||
class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel):
|
||||
|
|
@ -168,6 +168,41 @@ def test_get_and_set_lr():
|
|||
assert lr != new_lr
|
||||
|
||||
|
||||
def test_scheduler_step():
    """Check that one scheduler step after an optimizer step changes the optimizer's learning rate."""
    # Build the models used by this test.
    simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()

    # Random batch: features first, then integer class labels.
    features = torch.randn(64, 784).numpy()
    targets = torch.randint(high=10, size=(64,), dtype=torch.int32).numpy()
    batch = [features, targets]

    with tempfile.TemporaryDirectory() as temp_dir:
        # Write the models and checkpoint to disk so they can be loaded back below.
        checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path(
            temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model
        )

        # Load the checkpoint, then build the module, optimizer and scheduler on top of it.
        checkpoint_state = CheckpointState(checkpoint_file_path)
        model = Module(model_file_path, checkpoint_state)
        optimizer = Optimizer(optimizer_file_path, model)
        scheduler = LinearLRScheduler(optimizer, 1, 2, 0.2)

        # Right after the scheduler is created the learning rate is expected to be 0.
        lr_before = optimizer.get_learning_rate()
        assert np.allclose(lr_before, 0.0)

        # Run a single training iteration followed by a scheduler step.
        model.train()
        model(batch)
        optimizer.step()
        scheduler.step()

        # The scheduler step must have moved the learning rate away from its starting value.
        lr_after = optimizer.get_learning_rate()
        assert lr_after != lr_before
|
||||
|
||||
|
||||
def test_training_module_checkpoint():
|
||||
# Initialize Models
|
||||
simple_model, onnx_model, _, _, _ = _create_training_models()
|
||||
|
|
|
|||
Loading…
Reference in a new issue