diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index f47dac274b..ca9b63f80a 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -897,6 +897,12 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn model->NamedParameters(), session_option, GetTrainingORTEnv(), std::vector>()); })) + .def("set_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer, float lr) -> void { + ORT_THROW_IF_ERROR(optimizer->SetLearningRate(lr)); + }) + .def("get_learning_rate", [](onnxruntime::training::api::Optimizer* optimizer) -> float { + return optimizer->GetLearningRate(); + }) .def("optimizer_step", [](onnxruntime::training::api::Optimizer* optimizer) -> void { ORT_THROW_IF_ERROR(optimizer->Step()); }); diff --git a/orttraining/orttraining/python/training/api/optimizer.py b/orttraining/orttraining/python/training/api/optimizer.py index afb5185588..65480a5e9e 100644 --- a/orttraining/orttraining/python/training/api/optimizer.py +++ b/orttraining/orttraining/python/training/api/optimizer.py @@ -22,3 +22,15 @@ class Optimizer: Run Optimizer Step. """ self._optimizer.optimizer_step() + + def set_learning_rate(self, learning_rate: float) -> None: + """ + Set Learning Rate. + """ + self._optimizer.set_learning_rate(learning_rate) + + def get_learning_rate(self) -> float: + """ + Get Learning Rate. + """ + return self._optimizer.get_learning_rate() diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 278f8f59b6..378e58a3ac 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -142,6 +142,32 @@ def test_optimizer_step(): # TODO : Check if parameters changed from before and after optimizer step. +def test_get_and_set_lr(): + # Initialize Models + simple_model, onnx_model, optimizer_model, _, _ = _create_training_models() + + with tempfile.TemporaryDirectory() as temp_dir: + # Save models & checkpoint files to load them later. + checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path( + temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model + ) + # Create Checkpoint State. + state = CheckpointState(checkpoint_file_path) + # Create a Module and Optimizer. + model = Module(model_file_path, state) + optimizer = Optimizer(optimizer_file_path, model) + + # Test get and set learning rate. + lr = optimizer.get_learning_rate() + assert round(lr, 3) == 0.001 + + optimizer.set_learning_rate(0.5) + new_lr = optimizer.get_learning_rate() + + assert np.isclose(new_lr, 0.5) + assert lr != new_lr + + def test_training_module_checkpoint(): # Initialize Models simple_model, onnx_model, _, _, _ = _create_training_models()