Add Utils for federated learning scenarios (#13014)

**Description**: Utilities for federated learning.

**Motivation and Context**
- This PR adds utilities intended for use in federated learning
scenarios.
- It exposes Python bindings for some of these utilities and adds a
utility to calculate the difference between two buffers.

Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
This commit is contained in:
Adam Louly 2022-10-17 12:39:43 -07:00 committed by GitHub
parent b4853a978a
commit 68eff69ab1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 85 additions and 0 deletions

View file

@ -849,6 +849,15 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
// Python `reset_grad`: forwards to Module::ResetGrad(). The returned status is
// checked by ORT_THROW_IF_ERROR, so a failure surfaces as a thrown error
// instead of a return value (the lambda itself returns void).
.def("reset_grad", [](onnxruntime::training::api::Module* model) -> void {
ORT_THROW_IF_ERROR(model->ResetGrad());
})
// Python `copy_parameters_to_buffer`: hands the caller-supplied OrtValue to
// Module::CopyParametersToBuffer as the destination. Nothing is returned; the
// status is checked by ORT_THROW_IF_ERROR.
// NOTE(review): presumably `output` must already be allocated and large enough
// to hold the parameters -- confirm against Module::CopyParametersToBuffer.
.def("copy_parameters_to_buffer", [](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
})
// Python `copy_buffer_to_parameters`: hands the caller-supplied OrtValue to
// Module::CopyBufferToParameters as the source. Nothing is returned; the
// status is checked by ORT_THROW_IF_ERROR.
// NOTE(review): presumably `input` must have the layout produced by
// copy_parameters_to_buffer -- confirm against Module::CopyBufferToParameters.
.def("copy_buffer_to_parameters", [](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
})
// Python `get_parameters_size`: returns Module::GetParametersSize(trainable_only)
// as a size_t. The `trainable_only` flag is passed through unchanged.
// NOTE(review): presumably the result is an element count rather than a byte
// count -- confirm against Module::GetParametersSize.
.def("get_parameters_size", [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
return model->GetParametersSize(trainable_only);
})
.def("save_checkpoint", [](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void {
onnxruntime::training::api::CheckpointState state;
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state));

View file

@ -2,6 +2,8 @@
# Licensed under the MIT License.
# module.py
import numpy as np
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
@ -81,3 +83,32 @@ class Module:
"""
# TODO : move this out of Module Class.
self._model.save_checkpoint(ckpt_uri)
# This function will change once the parameters are exposed.
def get_contiguous_parameters(self, trainable_only: bool = False) -> OrtValue:
    """
    Copies the model parameters into one flat buffer and returns that buffer.

    Args:
        trainable_only: Forwarded to get_parameters_size to determine how many
            elements the buffer holds.

    Returns:
        The `_ortvalue` attribute of a newly created one-dimensional float32
        OrtValue on device "cpu" (device id 0), after the underlying model has
        copied its parameters into it.
    """
    # NOTE(review): trainable_only only sizes the buffer; it is not passed to
    # copy_parameters_to_buffer -- confirm the copy covers the same parameters.
    element_count = self.get_parameters_size(trainable_only)
    buffer_shape = [element_count]
    wrapper = OrtValue.ortvalue_from_shape_and_type(buffer_shape, np.float32, "cpu", 0)

    # The underlying model is handed the wrapper's `_ortvalue`, not the wrapper.
    flat_buffer = wrapper._ortvalue
    self._model.copy_parameters_to_buffer(flat_buffer)
    return flat_buffer
def get_parameters_size(self, trainable_only: bool = False) -> int:
    """
    Reports the parameter size as computed by the underlying model.

    Args:
        trainable_only: Passed through unchanged to the underlying model's
            get_parameters_size.

    Returns:
        The size reported by the underlying model.
    """
    parameters_size = self._model.get_parameters_size(trainable_only)
    return parameters_size
def copy_buffer_to_parameters(self, buffer) -> None:
    """
    Writes the contents of a buffer back into the model parameters.

    Args:
        buffer: Handed unchanged to the underlying model's
            copy_buffer_to_parameters. Presumably a flat buffer such as the
            one returned by get_contiguous_parameters -- confirm.
    """
    native_model = self._model
    native_model.copy_buffer_to_parameters(buffer)

View file

@ -1,6 +1,7 @@
import os
import tempfile
import numpy as np
import onnx
import torch
from orttraining_test_onnxblock import _get_models
@ -167,3 +168,46 @@ def test_training_module_checkpoint():
# TODO : Load checkpoint to a zeroed model and assert parameters are different.
assert os.path.exists(checkpoint_save_path)
def test_copy_buffer_to_parameters():
    """Checks that copy_buffer_to_parameters restores a previously captured parameter snapshot."""
    # Build the models from which the training artifacts are generated.
    simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()

    # Random batch: 64 samples of 784 features with integer labels in [0, 10).
    features = torch.randn(64, 784).numpy()
    targets = torch.randint(high=10, size=(64,), dtype=torch.int32).numpy()
    forward_inputs = [features, targets]

    with tempfile.TemporaryDirectory() as workdir:
        # Write the model and checkpoint files so they can be loaded below.
        ckpt_path, training_model_path, optimizer_path = _get_test_models_path(
            workdir, simple_model, onnx_model, optimizer_model=optimizer_model
        )

        # Load the checkpoint, then construct the Module and its Optimizer.
        checkpoint_state = CheckpointState(ckpt_path)
        module = Module(training_model_path, checkpoint_state)
        optim = Optimizer(optimizer_path, module)

        # Snapshot the parameters before any training happens.
        params_before = module.get_contiguous_parameters()

        # One training step: forward pass in train mode, then an optimizer update.
        module.train()
        module(forward_inputs)
        optim.step()

        # The optimizer step must have modified the parameters.
        params_after_step = module.get_contiguous_parameters()
        assert not np.array_equal(params_before.numpy(), params_after_step.numpy())

        # Write the snapshot back into the module and read the parameters again.
        module.copy_buffer_to_parameters(params_before)
        params_restored = module.get_contiguous_parameters()

        # The restored parameters must match the original snapshot exactly.
        assert np.array_equal(params_before.numpy(), params_restored.numpy())

View file

@ -492,6 +492,7 @@ if enable_training:
]
)
if enable_training_on_device:
packages.append("onnxruntime.training.api")
packages.append("onnxruntime.training.onnxblock")
packages.append("onnxruntime.training.onnxblock.loss")
packages.append("onnxruntime.training.onnxblock.optim")