adding get and set lr for optimizer (#13661)

### Description
Exposing get and set Learning rate for optimizer


### Motivation and Context
you can now set learning rate for optimizer.
This commit is contained in:
Adam Louly 2022-12-07 13:59:11 -06:00 committed by GitHub
parent 983877c712
commit f453d2845e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 0 deletions

View file

@ -897,6 +897,12 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
model->NamedParameters(), session_option,
GetTrainingORTEnv(), std::vector<std::shared_ptr<IExecutionProvider>>());
}))
.def("set_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer, float lr) -> void {
ORT_THROW_IF_ERROR(optimizer->SetLearningRate(lr));
})
.def("get_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer) -> float {
return optimizer->GetLearningRate();
})
.def("optimizer_step", [](onnxruntime::training::api::Optimizer* optimizer) -> void {
ORT_THROW_IF_ERROR(optimizer->Step());
});

View file

@ -22,3 +22,15 @@ class Optimizer:
Run Optimizer Step.
"""
self._optimizer.optimizer_step()
def set_learning_rate(self, learning_rate: float) -> None:
"""
Set Learning Rate.
"""
self._optimizer.set_learning_rate(learning_rate)
def get_learning_rate(self) -> float:
"""
Get Learning Rate.
"""
return self._optimizer.get_learning_rate()

View file

@ -142,6 +142,32 @@ def test_optimizer_step():
# TODO : Check if parameters changed from before and after optimizer step.
def test_get_and_set_lr():
# Initialize Models
simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()
with tempfile.TemporaryDirectory() as temp_dir:
# Save models & checkpoint files to load them later.
checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path(
temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model
)
# Create Checkpoint State.
state = CheckpointState(checkpoint_file_path)
# Create a Module and Optimizer.
model = Module(model_file_path, state)
optimizer = Optimizer(optimizer_file_path, model)
# Test get and set learning rate.
lr = optimizer.get_learning_rate()
assert round(lr, 3) == 0.001
optimizer.set_learning_rate(0.5)
new_lr = optimizer.get_learning_rate()
assert np.isclose(new_lr, 0.5)
assert lr != new_lr
def test_training_module_checkpoint():
# Initialize Models
simple_model, onnx_model, _, _, _ = _create_training_models()