mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
adding get and set lr for optimizer (#13661)
### Description Exposing get and set Learning rate for optimizer ### Motivation and Context you can now set learning rate for optimizer.
This commit is contained in:
parent
983877c712
commit
f453d2845e
3 changed files with 44 additions and 0 deletions
|
|
@ -897,6 +897,12 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
model->NamedParameters(), session_option,
|
||||
GetTrainingORTEnv(), std::vector<std::shared_ptr<IExecutionProvider>>());
|
||||
}))
|
||||
.def("set_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer, float lr) -> void {
|
||||
ORT_THROW_IF_ERROR(optimizer->SetLearningRate(lr));
|
||||
})
|
||||
.def("get_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer) -> float {
|
||||
return optimizer->GetLearningRate();
|
||||
})
|
||||
.def("optimizer_step", [](onnxruntime::training::api::Optimizer* optimizer) -> void {
|
||||
ORT_THROW_IF_ERROR(optimizer->Step());
|
||||
});
|
||||
|
|
|
|||
|
|
@ -22,3 +22,15 @@ class Optimizer:
|
|||
Run Optimizer Step.
|
||||
"""
|
||||
self._optimizer.optimizer_step()
|
||||
|
||||
def set_learning_rate(self, learning_rate: float) -> None:
|
||||
"""
|
||||
Set Learning Rate.
|
||||
"""
|
||||
self._optimizer.set_learning_rate(learning_rate)
|
||||
|
||||
def get_learning_rate(self) -> float:
|
||||
"""
|
||||
Get Learning Rate.
|
||||
"""
|
||||
return self._optimizer.get_learning_rate()
|
||||
|
|
|
|||
|
|
@ -142,6 +142,32 @@ def test_optimizer_step():
|
|||
# TODO : Check if parameters changed from before and after optimizer step.
|
||||
|
||||
|
||||
def test_get_and_set_lr():
|
||||
# Initialize Models
|
||||
simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save models & checkpoint files to load them later.
|
||||
checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path(
|
||||
temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model
|
||||
)
|
||||
# Create Checkpoint State.
|
||||
state = CheckpointState(checkpoint_file_path)
|
||||
# Create a Module and Optimizer.
|
||||
model = Module(model_file_path, state)
|
||||
optimizer = Optimizer(optimizer_file_path, model)
|
||||
|
||||
# Test get and set learning rate.
|
||||
lr = optimizer.get_learning_rate()
|
||||
assert round(lr, 3) == 0.001
|
||||
|
||||
optimizer.set_learning_rate(0.5)
|
||||
new_lr = optimizer.get_learning_rate()
|
||||
|
||||
assert np.isclose(new_lr, 0.5)
|
||||
assert lr != new_lr
|
||||
|
||||
|
||||
def test_training_module_checkpoint():
|
||||
# Initialize Models
|
||||
simple_model, onnx_model, _, _, _ = _create_training_models()
|
||||
|
|
|
|||
Loading…
Reference in a new issue