From 68eff69ab193fa8201ce131168a6cfa78cc8eeda Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Mon, 17 Oct 2022 12:39:43 -0700 Subject: [PATCH] Add Utils for federated learning scenarios (#13014) **Description**: Utils for federated learning. **Motivation and Context** - This PR includes utils that will be used in federated learning scenarios. - Exposes Python bindings for some utils and adds a util to calculate the difference between two buffers. Co-authored-by: Adam Louly Co-authored-by: Baiju Meswani --- .../python/orttraining_pybind_state.cc | 9 ++++ .../orttraining/python/training/api/module.py | 31 +++++++++++++ .../orttraining_test_python_bindings.py | 44 +++++++++++++++++++ setup.py | 1 + 4 files changed, 85 insertions(+) diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 89c0b0ade0..e68d25d4d2 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -849,6 +849,15 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def("reset_grad", [](onnxruntime::training::api::Module* model) -> void { ORT_THROW_IF_ERROR(model->ResetGrad()); }) + .def("copy_parameters_to_buffer", [](onnxruntime::training::api::Module* model, OrtValue& output) -> void { + ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output)); + }) + .def("copy_buffer_to_parameters", [](onnxruntime::training::api::Module* model, OrtValue& input) -> void { + ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input)); + }) + .def("get_parameters_size", [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t { + return model->GetParametersSize(trainable_only); + }) .def("save_checkpoint", [](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void { onnxruntime::training::api::CheckpointState state; 
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state)); diff --git a/orttraining/orttraining/python/training/api/module.py b/orttraining/orttraining/python/training/api/module.py index 2e2cb3e31f..037c533c98 100644 --- a/orttraining/orttraining/python/training/api/module.py +++ b/orttraining/orttraining/python/training/api/module.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. # module.py +import numpy as np + from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import OrtValue from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector @@ -81,3 +83,32 @@ class Module: """ # TODO : move this out of Module Class. self._model.save_checkpoint(ckpt_uri) + + # This function will change when the parameters will be exposed. + def get_contiguous_parameters(self, trainable_only: bool = False) -> OrtValue: + """ + Returns contiguous parameters object. + """ + parameters = OrtValue.ortvalue_from_shape_and_type( + [ + self.get_parameters_size(trainable_only), + ], + np.float32, + "cpu", + 0, + )._ortvalue + self._model.copy_parameters_to_buffer(parameters) + + return parameters + + def get_parameters_size(self, trainable_only: bool = False) -> int: + """ + Returns the size of the parameters. + """ + return self._model.get_parameters_size(trainable_only) + + def copy_buffer_to_parameters(self, buffer) -> None: + """ + Copies buffer to parameters. 
+ """ + self._model.copy_buffer_to_parameters(buffer) diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 8d65a23ace..4f26fdb31c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -1,6 +1,7 @@ import os import tempfile +import numpy as np import onnx import torch from orttraining_test_onnxblock import _get_models @@ -167,3 +168,46 @@ def test_training_module_checkpoint(): # TODO : Load checkpoint to a zeroed model and assert parameters are different. assert os.path.exists(checkpoint_save_path) + + +def test_copy_buffer_to_parameters(): + # Initialize Models + simple_model, onnx_model, optimizer_model, _, _ = _create_training_models() + + # Generating random data for testing. + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int32).numpy() + forward_inputs = [inputs, labels] + + with tempfile.TemporaryDirectory() as temp_dir: + # Save models & checkpoint files to load them later. + checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path( + temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model + ) + state = CheckpointState(checkpoint_file_path) + + # Create a Module and Optimizer. + model = Module(model_file_path, state) + optimizer = Optimizer(optimizer_file_path, model) + + # Keep a copy of the parameters. + old_output_params = model.get_contiguous_parameters() + + # Run a Training Step. + model.train() + model(forward_inputs) + optimizer.step() + + # Get the new parameters. + output_params = model.get_contiguous_parameters() + # Make sure old params are different from new params. + assert not np.array_equal(old_output_params.numpy(), output_params.numpy()) + + # Copy the old parameters to the model. 
+ model.copy_buffer_to_parameters(old_output_params) + + # Get the saved parameters. + saved_params = model.get_contiguous_parameters() + + # Make sure the saved parameters are the same as the old parameters. + assert np.array_equal(old_output_params.numpy(), saved_params.numpy()) diff --git a/setup.py b/setup.py index 2f5b99b5e1..536ee576ee 100644 --- a/setup.py +++ b/setup.py @@ -492,6 +492,7 @@ if enable_training: ] ) if enable_training_on_device: + packages.append("onnxruntime.training.api") packages.append("onnxruntime.training.onnxblock") packages.append("onnxruntime.training.onnxblock.loss") packages.append("onnxruntime.training.onnxblock.optim")