mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Add Utils for federated learning scenarios (#13014)
**Description**: Utilities for federated learning. **Motivation and Context** - This PR adds utilities that will be used in federated learning scenarios. - It exposes Python bindings for some existing utilities and adds a utility to calculate the difference between two buffers. Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: Baiju Meswani <bmeswani@microsoft.com>
This commit is contained in:
parent
b4853a978a
commit
68eff69ab1
4 changed files with 85 additions and 0 deletions
|
|
@ -849,6 +849,15 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
.def("reset_grad", [](onnxruntime::training::api::Module* model) -> void {
|
||||
ORT_THROW_IF_ERROR(model->ResetGrad());
|
||||
})
|
||||
.def("copy_parameters_to_buffer", [](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
|
||||
ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
|
||||
})
|
||||
.def("copy_buffer_to_parameters", [](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
|
||||
ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
|
||||
})
|
||||
.def("get_parameters_size", [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
|
||||
return model->GetParametersSize(trainable_only);
|
||||
})
|
||||
.def("save_checkpoint", [](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void {
|
||||
onnxruntime::training::api::CheckpointState state;
|
||||
ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state));
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# module.py
|
||||
|
||||
import numpy as np
|
||||
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
|
||||
|
|
@ -81,3 +83,32 @@ class Module:
|
|||
"""
|
||||
# TODO : move this out of Module Class.
|
||||
self._model.save_checkpoint(ckpt_uri)
|
||||
|
||||
# This function will change once the parameters are exposed.
|
||||
def get_contiguous_parameters(self, trainable_only: bool = False) -> OrtValue:
    """
    Returns the model parameters copied into one contiguous buffer.

    A one-dimensional buffer is created through
    ``OrtValue.ortvalue_from_shape_and_type`` with
    ``self.get_parameters_size(trainable_only)`` elements, ``np.float32`` as
    the element type, ``"cpu"`` as the device type and ``0`` as the device id.
    The underlying model then fills it via ``copy_parameters_to_buffer``.

    Args:
        trainable_only: Forwarded to ``get_parameters_size`` to choose the
            length of the buffer.

    Returns:
        The ``_ortvalue`` attribute of the newly created ``OrtValue`` (not the
        wrapper object itself), after the parameters have been copied into it.
    """
    # NOTE(review): trainable_only only determines the buffer length here; it
    # is not passed on to copy_parameters_to_buffer. Confirm that the copy
    # fills a buffer of this size for both values of the flag.
    buffer_shape = [self.get_parameters_size(trainable_only)]
    parameters = OrtValue.ortvalue_from_shape_and_type(buffer_shape, np.float32, "cpu", 0)._ortvalue
    self._model.copy_parameters_to_buffer(parameters)

    return parameters
|
||||
|
||||
def get_parameters_size(self, trainable_only: bool = False) -> int:
    """
    Returns the size of the parameters.

    Delegates to ``get_parameters_size`` on the underlying ``self._model``
    and returns its result unchanged.

    Args:
        trainable_only: Passed through as-is to the underlying model.

    Returns:
        The size reported by the underlying model.
    """
    return self._model.get_parameters_size(trainable_only)
|
||||
|
||||
def copy_buffer_to_parameters(self, buffer) -> None:
    """
    Copies buffer to parameters.

    Delegates to ``copy_buffer_to_parameters`` on the underlying
    ``self._model``.

    Args:
        buffer: The buffer handed to the underlying model. Presumably an
            object like the one returned by ``get_contiguous_parameters``
            (confirm against callers).
    """
    self._model.copy_buffer_to_parameters(buffer)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from orttraining_test_onnxblock import _get_models
|
||||
|
|
@ -167,3 +168,46 @@ def test_training_module_checkpoint():
|
|||
|
||||
# TODO : Load checkpoint to a zeroed model and assert parameters are different.
|
||||
assert os.path.exists(checkpoint_save_path)
|
||||
|
||||
|
||||
def test_copy_buffer_to_parameters():
    """Checks that a saved parameter buffer can be restored after a training step."""
    # Initialize Models
    simple_model, onnx_model, optimizer_model, _, _ = _create_training_models()

    # Generating random data for testing.
    inputs = torch.randn(64, 784).numpy()
    labels = torch.randint(high=10, size=(64,), dtype=torch.int32).numpy()
    forward_inputs = [inputs, labels]

    # Everything below stays inside the context manager so the temporary
    # model and checkpoint files exist for the whole test.
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save models & checkpoint files to load them later.
        checkpoint_file_path, model_file_path, optimizer_file_path = _get_test_models_path(
            temp_dir, simple_model, onnx_model, optimizer_model=optimizer_model
        )
        state = CheckpointState(checkpoint_file_path)

        # Create a Module and Optimizer.
        model = Module(model_file_path, state)
        optimizer = Optimizer(optimizer_file_path, model)

        # Keep a copy of the parameters.
        old_output_params = model.get_contiguous_parameters()

        # Run a Training Step.
        model.train()
        model(forward_inputs)
        optimizer.step()

        # Get the new parameters.
        output_params = model.get_contiguous_parameters()
        # Make sure old params are different from new params.
        assert not np.array_equal(old_output_params.numpy(), output_params.numpy())

        # Copy the old parameters to the model.
        model.copy_buffer_to_parameters(old_output_params)

        # Get the saved parameters.
        saved_params = model.get_contiguous_parameters()

        # Make sure the saved parameters are the same as the old parameters.
        assert np.array_equal(old_output_params.numpy(), saved_params.numpy())
|
||||
|
|
|
|||
1
setup.py
1
setup.py
|
|
@ -492,6 +492,7 @@ if enable_training:
|
|||
]
|
||||
)
|
||||
if enable_training_on_device:
|
||||
packages.append("onnxruntime.training.api")
|
||||
packages.append("onnxruntime.training.onnxblock")
|
||||
packages.append("onnxruntime.training.onnxblock.loss")
|
||||
packages.append("onnxruntime.training.onnxblock.optim")
|
||||
|
|
|
|||
Loading…
Reference in a new issue