From f3dcbf539aa410f31a0e4bce1807f79331fe1c1f Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Tue, 26 Jul 2022 11:08:50 -0500 Subject: [PATCH] Checkpoint load inference (#12168) * LoadCheckPoint to tensor cpp functions (draft) * Load Checkpoint into inference model * fix python lint * fix python lint * Fixing lint and some unused imports * added assert for zero weights model, resolved other issues * resolved issues * Solved issues * changed variable names for get_models * paparameters names missmatched fix Co-authored-by: Adam Louly --- .../python/orttraining_pybind_state.cc | 19 +++++++ .../python/training/onnxblock/__init__.py | 2 +- .../training/onnxblock/checkpoint_utils.py | 19 +++++++ .../test/python/orttraining_test_onnxblock.py | 51 ++++++++++++++++++- .../orttraining/training_api/checkpoint.cc | 32 ++++++++++++ .../training_api/include/checkpoint.h | 9 ++++ 6 files changed, 129 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 44465c6651..de5206935b 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -840,6 +840,25 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos, non_trainable_tensor_protos, checkpoint_path)); }); + m.def("load_checkpoint", + [](const std::string& checkpoint_path) { + std::vector tensor_protos; + ORT_THROW_IF_ERROR(onnxruntime::training::api::LoadCheckpoint(checkpoint_path, tensor_protos)); + std::vector tensor_protos_pybytes(tensor_protos.size()); + + const auto parse_tensor_proto_to_pybytes = + [](std::vector& tensor_protos_pybytes, const std::vector& tensor_protos) { + for (size_t i = 0; i < tensor_protos.size(); ++i) { + std::string tensor_proto_str; + tensor_protos[i].SerializeToString(&tensor_proto_str); + tensor_protos_pybytes[i] = tensor_proto_str; + } + }; + + parse_tensor_proto_to_pybytes(tensor_protos_pybytes, tensor_protos); + + return tensor_protos_pybytes; + }); #endif } diff --git a/orttraining/orttraining/python/training/onnxblock/__init__.py b/orttraining/orttraining/python/training/onnxblock/__init__.py index 877f725f01..e6b43510c9 100644 --- a/orttraining/orttraining/python/training/onnxblock/__init__.py +++ b/orttraining/orttraining/python/training/onnxblock/__init__.py @@ -6,6 +6,6 @@ from . import loss, optim from .building_blocks import Block -from .checkpoint_utils import save_checkpoint +from .checkpoint_utils import load_checkpoint_to_model, save_checkpoint from .model import Model, TrainingModel from .model_accessor import onnx_model diff --git a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py index b635708bc8..ae0e0cfb6c 100644 --- a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py @@ -2,6 +2,9 @@ # Licensed under the MIT License. # checkpoint_utils.py +from onnx import TensorProto + +from onnxruntime.capi._pybind_state import load_checkpoint as _internal_load_checkpoint from onnxruntime.capi._pybind_state import save_checkpoint as _internal_save_checkpoint @@ -17,3 +20,19 @@ def save_checkpoint(parameters, path_to_checkpoint): trainable_params = [param.SerializeToString() for param in trainable_params] non_trainable_params = [param.SerializeToString() for param in non_trainable_params] _internal_save_checkpoint(trainable_params, non_trainable_params, path_to_checkpoint) + + +def load_checkpoint_to_model(path_to_checkpoint, model): + """Loads the checkpoint to an onnx inference model.""" + + # Load the parameters from the checkpoint + parameters = _internal_load_checkpoint(path_to_checkpoint) + + parameters_dict = {} + for param in parameters: + param_proto = TensorProto() + param_proto.ParseFromString(param) + parameters_dict[param_proto.name] = param_proto + + for initializer in model.graph.initializer: + initializer.CopyFrom(parameters_dict[initializer.name]) diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py index e66d7ada2c..15ae7824dd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py @@ -133,9 +133,16 @@ def _to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() -def _get_models(device, batch_size, input_size, hidden_size, output_size): +def _get_models(device, batch_size, input_size, hidden_size, output_size, zero_flag=False): """Returns the pt and onnx models for SimpleNet""" pt_model = SimpleNet(input_size, hidden_size, output_size).to(device) + + # setting all initial weights to zero + if zero_flag: + with torch.no_grad(): + for param in pt_model.parameters(): + param.zero_() + x = torch.randn(batch_size, input_size, device=device) onnx_model = _get_onnx_model(pt_model, (x,)) @@ -537,11 +544,51 @@ def test_save_checkpoint(): with tempfile.TemporaryDirectory() as checkpoint_dir_name: checkpoint_file_path = os.path.join(checkpoint_dir_name, "checkpoint") onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_file_path) - # Then assert os.path.exists(checkpoint_file_path) +def test_load_checkpoint(): + # Given + device = "cuda" + batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 + _, zero_onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size, zero_flag=True) + for i in range(len(zero_onnx_model.graph.initializer)): + zero_np = onnx.numpy_helper.to_array(zero_onnx_model.graph.initializer[i]) + assert np.allclose(zero_np, np.zeros(zero_np.shape)) + + _, onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size) + + # Copy of onnx_model for comparison + onnx_model_copy = copy.deepcopy(onnx_model) + + simple_model = SimpleTrainingModelWithMSELoss() + + # When + simple_model.requires_grad("fc2.weight", False) + simple_model.requires_grad("fc1.bias", False) + + with onnxblock.onnx_model(onnx_model): + _ = simple_model(onnx_model.graph.output[0].name) + trainable_params, non_trainable_params = simple_model.parameters() + + with tempfile.TemporaryDirectory() as checkpoint_dir_name: + checkpoint_file_path = os.path.join(checkpoint_dir_name, "checkpoint") + onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_file_path) + + # Load checkpoint parameters to the new simple model + onnxblock.load_checkpoint_to_model(checkpoint_file_path, zero_onnx_model) + + # Then + onnx_model_copy.graph.initializer.sort(key=lambda x: x.name) + zero_onnx_model.graph.initializer.sort(key=lambda x: x.name) + + for i, _ in enumerate(onnx_model_copy.graph.initializer): + onnx_np = onnx.numpy_helper.to_array(onnx_model_copy.graph.initializer[i]) + zero_np = onnx.numpy_helper.to_array(zero_onnx_model.graph.initializer[i]) + assert np.allclose(onnx_np, zero_np) + + def test_set_requires_grad_on_parameters(): # Given device = "cuda" diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index 600473e6c0..78ffb37702 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -530,6 +530,33 @@ Status OrtLoadCustomPropertyInternal(const PathString& property_folder_path, return Status::OK(); } +Status OrtLoadInternal(const PathString& checkpoint_path, + std::vector& param_tensor_protos) { + // Find tensor proto files. + std::vector tensor_proto_filenames; + FilterFilesFromDirectory( + checkpoint_path, + [&tensor_proto_filenames](const PathChar* filename) -> bool { + PathString filename_str = filename; + if (StringEndsWith(filename_str, k_tensor_proto_file_name)) { + tensor_proto_filenames.push_back(filename_str); + } + return true; + }); + + // Load tensor protos to the tensorProto Vector + for (const auto& tensor_file_path : tensor_proto_filenames) { + std::vector tensor_protos{}; + const auto tensor_file_full_path = ConcatPathComponent(checkpoint_path, tensor_file_path); + LoadTensorProtoFromFile(tensor_file_full_path, tensor_protos, "[params]"); + + for (const auto& tensor_proto : tensor_protos) { + param_tensor_protos.push_back(tensor_proto); + } + } + return Status::OK(); +} + Status OrtLoadInternal(const PathString& checkpoint_path, CheckpointState& state) { ORT_ENFORCE(Env::Default().FolderExists(checkpoint_path), "Checkpoint folder not exit"); ORT_RETURN_IF_ERROR(OrtLoadModuleStatesInternal(checkpoint_path, state.module_checkpoint_state)); @@ -554,6 +581,11 @@ Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkp return OrtLoadInternal(checkpoint_path, checkpoint_states); } +Status LoadCheckpoint(const PathString& checkpoint_path, + std::vector& param_tensor_protos) { + return OrtLoadInternal(checkpoint_path, param_tensor_protos); +} + } // namespace api } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/training_api/include/checkpoint.h b/orttraining/orttraining/training_api/include/checkpoint.h index e85585bfa7..fdbd0dde6d 100644 --- a/orttraining/orttraining/training_api/include/checkpoint.h +++ b/orttraining/orttraining/training_api/include/checkpoint.h @@ -82,6 +82,15 @@ Status SaveCheckpoint(const std::vector& trainable_ Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_state); +/** + * @brief Load training states from ORT checkpoint and returns vector of tensor protos representing the model parameters. + * @param param_tensor_protos parameters in TensorProto format. + * @param checkpoint_path folder where checkpoint is stored. + * @return Status + */ +Status LoadCheckpoint(const PathString& checkpoint_path, + std::vector& param_tensor_protos); + } // namespace api } // namespace training } // namespace onnxruntime