Checkpoint load inference (#12168)

* LoadCheckPoint to tensor cpp functions (draft)

* Load Checkpoint into inference model

* fix python lint

* fix python lint

* Fixing lint and some unused imports

* added assert for zero weights model, resolved other issues

* resolved issues

* Solved issues

* changed variable names for get_models

* parameter names mismatch fix

Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
Adam Louly 2022-07-26 11:08:50 -05:00 committed by GitHub
parent de57daaab0
commit f3dcbf539a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 129 additions and 3 deletions

View file

@ -840,6 +840,25 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos,
non_trainable_tensor_protos, checkpoint_path));
});
// Python binding `load_checkpoint(checkpoint_path)`.
// Loads the tensor protos stored in the checkpoint at `checkpoint_path` and hands each one
// back as a py::bytes object holding the serialized TensorProto. Throws if LoadCheckpoint
// reports an error status.
m.def("load_checkpoint",
      [](const std::string& checkpoint_path) {
        std::vector<TensorProto> loaded_protos;
        ORT_THROW_IF_ERROR(onnxruntime::training::api::LoadCheckpoint(checkpoint_path, loaded_protos));

        // One bytes slot per loaded tensor, filled in the order LoadCheckpoint produced them.
        std::vector<py::bytes> serialized_protos(loaded_protos.size());
        size_t slot = 0;
        for (const TensorProto& proto : loaded_protos) {
          std::string buffer;
          // NOTE(review): the bool result of SerializeToString is not checked here.
          proto.SerializeToString(&buffer);
          serialized_protos[slot] = buffer;
          ++slot;
        }
        return serialized_protos;
      });
#endif
}

View file

@ -6,6 +6,6 @@
from . import loss, optim
from .building_blocks import Block
from .checkpoint_utils import save_checkpoint
from .checkpoint_utils import load_checkpoint_to_model, save_checkpoint
from .model import Model, TrainingModel
from .model_accessor import onnx_model

View file

@ -2,6 +2,9 @@
# Licensed under the MIT License.
# checkpoint_utils.py
from onnx import TensorProto
from onnxruntime.capi._pybind_state import load_checkpoint as _internal_load_checkpoint
from onnxruntime.capi._pybind_state import save_checkpoint as _internal_save_checkpoint
@ -17,3 +20,19 @@ def save_checkpoint(parameters, path_to_checkpoint):
trainable_params = [param.SerializeToString() for param in trainable_params]
non_trainable_params = [param.SerializeToString() for param in non_trainable_params]
_internal_save_checkpoint(trainable_params, non_trainable_params, path_to_checkpoint)
def load_checkpoint_to_model(path_to_checkpoint, model):
    """Overwrite the initializers of an onnx model with the parameters stored in a checkpoint.

    Args:
        path_to_checkpoint: location of the checkpoint to read.
        model: onnx model whose ``graph.initializer`` entries are replaced in place.

    Raises:
        KeyError: if an initializer of ``model`` has no tensor of the same name in the checkpoint.
    """

    # Deserialize every checkpoint entry and index the resulting tensors by name.
    checkpoint_tensors = {}
    for serialized_tensor in _internal_load_checkpoint(path_to_checkpoint):
        tensor = TensorProto()
        tensor.ParseFromString(serialized_tensor)
        checkpoint_tensors[tensor.name] = tensor

    # Replace each model initializer with the checkpoint tensor that shares its name.
    for model_initializer in model.graph.initializer:
        model_initializer.CopyFrom(checkpoint_tensors[model_initializer.name])

View file

@ -133,9 +133,16 @@ def _to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def _get_models(device, batch_size, input_size, hidden_size, output_size):
def _get_models(device, batch_size, input_size, hidden_size, output_size, zero_flag=False):
"""Returns the pt and onnx models for SimpleNet"""
pt_model = SimpleNet(input_size, hidden_size, output_size).to(device)
# setting all initial weights to zero
if zero_flag:
with torch.no_grad():
for param in pt_model.parameters():
param.zero_()
x = torch.randn(batch_size, input_size, device=device)
onnx_model = _get_onnx_model(pt_model, (x,))
@ -537,11 +544,51 @@ def test_save_checkpoint():
with tempfile.TemporaryDirectory() as checkpoint_dir_name:
checkpoint_file_path = os.path.join(checkpoint_dir_name, "checkpoint")
onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_file_path)
# Then
assert os.path.exists(checkpoint_file_path)
def test_load_checkpoint():
    """Checks that a saved checkpoint can be loaded back into an all-zero onnx model.

    The parameters of a regular model are saved to a checkpoint, loaded into a model whose
    initializers are all zeros, and the result is compared against a copy of the regular model.
    """
    # Given
    device = "cuda"
    batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10

    # Target model: every initializer must start out as zeros.
    _, zero_onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size, zero_flag=True)
    for zero_initializer in zero_onnx_model.graph.initializer:
        zero_values = onnx.numpy_helper.to_array(zero_initializer)
        assert np.allclose(zero_values, np.zeros(zero_values.shape))

    _, onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size)

    # Reference copy of onnx_model, taken before the training model is built from it.
    onnx_model_copy = copy.deepcopy(onnx_model)

    simple_model = SimpleTrainingModelWithMSELoss()

    # When
    simple_model.requires_grad("fc2.weight", False)
    simple_model.requires_grad("fc1.bias", False)

    with onnxblock.onnx_model(onnx_model):
        _ = simple_model(onnx_model.graph.output[0].name)

    trainable_params, non_trainable_params = simple_model.parameters()

    with tempfile.TemporaryDirectory() as checkpoint_dir_name:
        checkpoint_file_path = os.path.join(checkpoint_dir_name, "checkpoint")
        onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_file_path)

        # Pour the saved parameters into the all-zero model.
        onnxblock.load_checkpoint_to_model(checkpoint_file_path, zero_onnx_model)

        # Then
        # Sort both initializer lists by name so they can be compared position by position.
        onnx_model_copy.graph.initializer.sort(key=lambda init: init.name)
        zero_onnx_model.graph.initializer.sort(key=lambda init: init.name)

        for position in range(len(onnx_model_copy.graph.initializer)):
            expected_values = onnx.numpy_helper.to_array(onnx_model_copy.graph.initializer[position])
            loaded_values = onnx.numpy_helper.to_array(zero_onnx_model.graph.initializer[position])
            assert np.allclose(expected_values, loaded_values)
def test_set_requires_grad_on_parameters():
# Given
device = "cuda"

View file

@ -530,6 +530,33 @@ Status OrtLoadCustomPropertyInternal(const PathString& property_folder_path,
return Status::OK();
}
// Reads the parameter tensor protos stored in the checkpoint folder `checkpoint_path`.
// Every directory entry whose name ends with k_tensor_proto_file_name is loaded, and the
// tensor protos it yields are appended to `param_tensor_protos`. Always returns Status::OK().
Status OrtLoadInternal(const PathString& checkpoint_path,
                       std::vector<ONNX_NAMESPACE::TensorProto>& param_tensor_protos) {
  // Collect the names of the tensor proto files found in the checkpoint folder.
  std::vector<PathString> matching_files;
  FilterFilesFromDirectory(
      checkpoint_path,
      [&matching_files](const PathChar* entry_name) -> bool {
        PathString entry(entry_name);
        if (StringEndsWith(entry, k_tensor_proto_file_name)) {
          matching_files.push_back(entry);
        }
        return true;
      });

  // Load each file and append its tensor protos to the output vector.
  for (const auto& file_name : matching_files) {
    const auto full_path = ConcatPathComponent<PathChar>(checkpoint_path, file_name);
    std::vector<ONNX_NAMESPACE::TensorProto> file_protos;
    LoadTensorProtoFromFile(full_path, file_protos, "[params]");
    param_tensor_protos.insert(param_tensor_protos.end(), file_protos.begin(), file_protos.end());
  }

  return Status::OK();
}
Status OrtLoadInternal(const PathString& checkpoint_path, CheckpointState& state) {
ORT_ENFORCE(Env::Default().FolderExists(checkpoint_path), "Checkpoint folder not exit");
ORT_RETURN_IF_ERROR(OrtLoadModuleStatesInternal(checkpoint_path, state.module_checkpoint_state));
@ -554,6 +581,11 @@ Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkp
return OrtLoadInternal(checkpoint_path, checkpoint_states);
}
// Loads the parameter tensor protos of the checkpoint at `checkpoint_path` into
// `param_tensor_protos` by delegating to the matching OrtLoadInternal overload,
// and forwards that call's Status to the caller.
Status LoadCheckpoint(const PathString& checkpoint_path,
                      std::vector<ONNX_NAMESPACE::TensorProto>& param_tensor_protos) {
  Status load_status = OrtLoadInternal(checkpoint_path, param_tensor_protos);
  return load_status;
}
} // namespace api
} // namespace training
} // namespace onnxruntime

View file

@ -82,6 +82,15 @@ Status SaveCheckpoint(const std::vector<ONNX_NAMESPACE::TensorProto>& trainable_
Status LoadCheckpoint(const PathString& checkpoint_path,
CheckpointState& checkpoint_state);
/**
 * @brief Load the model parameters stored in an ORT checkpoint as tensor protos.
 *
 * The parameters are delivered through the output argument; the return value only
 * carries the status of the operation.
 *
 * @param checkpoint_path folder where checkpoint is stored.
 * @param param_tensor_protos output vector that receives the parameters in TensorProto format.
 * @return Status
 */
Status LoadCheckpoint(const PathString& checkpoint_path,
std::vector<ONNX_NAMESPACE::TensorProto>& param_tensor_protos);
} // namespace api
} // namespace training
} // namespace onnxruntime