mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Checkpoint load inference (#12168)
* LoadCheckPoint to tensor cpp functions (draft) * Load Checkpoint into inference model * fix python lint * fix python lint * Fixing lint and some unused imports * added assert for zero weights model, resolved other issues * resolved issues * Solved issues * changed variable names for get_models * paparameters names missmatched fix Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
de57daaab0
commit
f3dcbf539a
6 changed files with 129 additions and 3 deletions
|
|
@ -840,6 +840,25 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos,
|
||||
non_trainable_tensor_protos, checkpoint_path));
|
||||
});
|
||||
m.def("load_checkpoint",
|
||||
[](const std::string& checkpoint_path) {
|
||||
std::vector<TensorProto> tensor_protos;
|
||||
ORT_THROW_IF_ERROR(onnxruntime::training::api::LoadCheckpoint(checkpoint_path, tensor_protos));
|
||||
std::vector<py::bytes> tensor_protos_pybytes(tensor_protos.size());
|
||||
|
||||
const auto parse_tensor_proto_to_pybytes =
|
||||
[](std::vector<py::bytes>& tensor_protos_pybytes, const std::vector<TensorProto>& tensor_protos) {
|
||||
for (size_t i = 0; i < tensor_protos.size(); ++i) {
|
||||
std::string tensor_proto_str;
|
||||
tensor_protos[i].SerializeToString(&tensor_proto_str);
|
||||
tensor_protos_pybytes[i] = tensor_proto_str;
|
||||
}
|
||||
};
|
||||
|
||||
parse_tensor_proto_to_pybytes(tensor_protos_pybytes, tensor_protos);
|
||||
|
||||
return tensor_protos_pybytes;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@
|
|||
|
||||
from . import loss, optim
|
||||
from .building_blocks import Block
|
||||
from .checkpoint_utils import save_checkpoint
|
||||
from .checkpoint_utils import load_checkpoint_to_model, save_checkpoint
|
||||
from .model import Model, TrainingModel
|
||||
from .model_accessor import onnx_model
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
# Licensed under the MIT License.
|
||||
# checkpoint_utils.py
|
||||
|
||||
from onnx import TensorProto
|
||||
|
||||
from onnxruntime.capi._pybind_state import load_checkpoint as _internal_load_checkpoint
|
||||
from onnxruntime.capi._pybind_state import save_checkpoint as _internal_save_checkpoint
|
||||
|
||||
|
||||
|
|
@ -17,3 +20,19 @@ def save_checkpoint(parameters, path_to_checkpoint):
|
|||
trainable_params = [param.SerializeToString() for param in trainable_params]
|
||||
non_trainable_params = [param.SerializeToString() for param in non_trainable_params]
|
||||
_internal_save_checkpoint(trainable_params, non_trainable_params, path_to_checkpoint)
|
||||
|
||||
|
||||
def load_checkpoint_to_model(path_to_checkpoint, model):
|
||||
"""Loads the checkpoint to an onnx inference model."""
|
||||
|
||||
# Load the parameters from the checkpoint
|
||||
parameters = _internal_load_checkpoint(path_to_checkpoint)
|
||||
|
||||
parameters_dict = {}
|
||||
for param in parameters:
|
||||
param_proto = TensorProto()
|
||||
param_proto.ParseFromString(param)
|
||||
parameters_dict[param_proto.name] = param_proto
|
||||
|
||||
for initializer in model.graph.initializer:
|
||||
initializer.CopyFrom(parameters_dict[initializer.name])
|
||||
|
|
|
|||
|
|
@ -133,9 +133,16 @@ def _to_numpy(tensor):
|
|||
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
||||
|
||||
|
||||
def _get_models(device, batch_size, input_size, hidden_size, output_size):
|
||||
def _get_models(device, batch_size, input_size, hidden_size, output_size, zero_flag=False):
|
||||
"""Returns the pt and onnx models for SimpleNet"""
|
||||
pt_model = SimpleNet(input_size, hidden_size, output_size).to(device)
|
||||
|
||||
# setting all initial weights to zero
|
||||
if zero_flag:
|
||||
with torch.no_grad():
|
||||
for param in pt_model.parameters():
|
||||
param.zero_()
|
||||
|
||||
x = torch.randn(batch_size, input_size, device=device)
|
||||
onnx_model = _get_onnx_model(pt_model, (x,))
|
||||
|
||||
|
|
@ -537,11 +544,51 @@ def test_save_checkpoint():
|
|||
with tempfile.TemporaryDirectory() as checkpoint_dir_name:
|
||||
checkpoint_file_path = os.path.join(checkpoint_dir_name, "checkpoint")
|
||||
onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_file_path)
|
||||
|
||||
# Then
|
||||
assert os.path.exists(checkpoint_file_path)
|
||||
|
||||
|
||||
def test_load_checkpoint():
|
||||
# Given
|
||||
device = "cuda"
|
||||
batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
|
||||
_, zero_onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size, zero_flag=True)
|
||||
for i in range(len(zero_onnx_model.graph.initializer)):
|
||||
zero_np = onnx.numpy_helper.to_array(zero_onnx_model.graph.initializer[i])
|
||||
assert np.allclose(zero_np, np.zeros(zero_np.shape))
|
||||
|
||||
_, onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size)
|
||||
|
||||
# Copy of onnx_model for comparison
|
||||
onnx_model_copy = copy.deepcopy(onnx_model)
|
||||
|
||||
simple_model = SimpleTrainingModelWithMSELoss()
|
||||
|
||||
# When
|
||||
simple_model.requires_grad("fc2.weight", False)
|
||||
simple_model.requires_grad("fc1.bias", False)
|
||||
|
||||
with onnxblock.onnx_model(onnx_model):
|
||||
_ = simple_model(onnx_model.graph.output[0].name)
|
||||
trainable_params, non_trainable_params = simple_model.parameters()
|
||||
|
||||
with tempfile.TemporaryDirectory() as checkpoint_dir_name:
|
||||
checkpoint_file_path = os.path.join(checkpoint_dir_name, "checkpoint")
|
||||
onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_file_path)
|
||||
|
||||
# Load checkpoint parameters to the new simple model
|
||||
onnxblock.load_checkpoint_to_model(checkpoint_file_path, zero_onnx_model)
|
||||
|
||||
# Then
|
||||
onnx_model_copy.graph.initializer.sort(key=lambda x: x.name)
|
||||
zero_onnx_model.graph.initializer.sort(key=lambda x: x.name)
|
||||
|
||||
for i, _ in enumerate(onnx_model_copy.graph.initializer):
|
||||
onnx_np = onnx.numpy_helper.to_array(onnx_model_copy.graph.initializer[i])
|
||||
zero_np = onnx.numpy_helper.to_array(zero_onnx_model.graph.initializer[i])
|
||||
assert np.allclose(onnx_np, zero_np)
|
||||
|
||||
|
||||
def test_set_requires_grad_on_parameters():
|
||||
# Given
|
||||
device = "cuda"
|
||||
|
|
|
|||
|
|
@ -530,6 +530,33 @@ Status OrtLoadCustomPropertyInternal(const PathString& property_folder_path,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OrtLoadInternal(const PathString& checkpoint_path,
|
||||
std::vector<ONNX_NAMESPACE::TensorProto>& param_tensor_protos) {
|
||||
// Find tensor proto files.
|
||||
std::vector<PathString> tensor_proto_filenames;
|
||||
FilterFilesFromDirectory(
|
||||
checkpoint_path,
|
||||
[&tensor_proto_filenames](const PathChar* filename) -> bool {
|
||||
PathString filename_str = filename;
|
||||
if (StringEndsWith(filename_str, k_tensor_proto_file_name)) {
|
||||
tensor_proto_filenames.push_back(filename_str);
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
// Load tensor protos to the tensorProto Vector
|
||||
for (const auto& tensor_file_path : tensor_proto_filenames) {
|
||||
std::vector<ONNX_NAMESPACE::TensorProto> tensor_protos{};
|
||||
const auto tensor_file_full_path = ConcatPathComponent<PathChar>(checkpoint_path, tensor_file_path);
|
||||
LoadTensorProtoFromFile(tensor_file_full_path, tensor_protos, "[params]");
|
||||
|
||||
for (const auto& tensor_proto : tensor_protos) {
|
||||
param_tensor_protos.push_back(tensor_proto);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OrtLoadInternal(const PathString& checkpoint_path, CheckpointState& state) {
|
||||
ORT_ENFORCE(Env::Default().FolderExists(checkpoint_path), "Checkpoint folder not exit");
|
||||
ORT_RETURN_IF_ERROR(OrtLoadModuleStatesInternal(checkpoint_path, state.module_checkpoint_state));
|
||||
|
|
@ -554,6 +581,11 @@ Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkp
|
|||
return OrtLoadInternal(checkpoint_path, checkpoint_states);
|
||||
}
|
||||
|
||||
Status LoadCheckpoint(const PathString& checkpoint_path,
|
||||
std::vector<ONNX_NAMESPACE::TensorProto>& param_tensor_protos) {
|
||||
return OrtLoadInternal(checkpoint_path, param_tensor_protos);
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -82,6 +82,15 @@ Status SaveCheckpoint(const std::vector<ONNX_NAMESPACE::TensorProto>& trainable_
|
|||
Status LoadCheckpoint(const PathString& checkpoint_path,
|
||||
CheckpointState& checkpoint_state);
|
||||
|
||||
/**
|
||||
* @brief Load training states from ORT checkpoint and returns vector of tensor protos representing the model parameters.
|
||||
* @param param_tensor_protos parameters in TensorProto format.
|
||||
* @param checkpoint_path folder where checkpoint is stored.
|
||||
* @return Status
|
||||
*/
|
||||
Status LoadCheckpoint(const PathString& checkpoint_path,
|
||||
std::vector<ONNX_NAMESPACE::TensorProto>& param_tensor_protos);
|
||||
|
||||
} // namespace api
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue