From a46c599a404bea970fa7948e23d7bc91bbe6dc9e Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 27 Oct 2022 09:34:01 -0700 Subject: [PATCH] Training API to export the eval model to an inference model (#13345) --- .../Training/NativeTrainingMethods.shared.cs | 1 + .../testdata/training_api/eval_model.onnx | Bin 0 -> 932 bytes .../testdata/training_api/training_model.onnx | Bin 3692 -> 3997 bytes .../python/orttraining_pybind_state.cc | 48 ++-- .../orttraining/python/training/api/module.py | 6 + .../orttraining_test_python_bindings.py | 30 ++- .../training_api/core/training_api_tests.cc | 236 ++++++++++++------ .../orttraining/training_api/checkpoint.cc | 26 +- .../orttraining/training_api/include/module.h | 6 + .../include/onnxruntime_training_c_api.h | 19 ++ .../include/onnxruntime_training_cxx_api.h | 8 + .../include/onnxruntime_training_cxx_inline.h | 10 + .../training_api/include/ort_training_apis.h | 4 + .../training_api/include/training_session.h | 11 +- .../orttraining/training_api/include/utils.h | 3 + .../orttraining/training_api/module.cc | 127 +++++++++- .../onnxruntime_training_c_api.cc | 26 ++ .../training_api/training_session.cc | 5 + orttraining/orttraining/training_api/utils.cc | 26 ++ 19 files changed, 463 insertions(+), 129 deletions(-) create mode 100644 onnxruntime/test/testdata/training_api/eval_model.onnx diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 63713a3977..3054b229af 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -32,6 +32,7 @@ namespace Microsoft.ML.OnnxRuntime public IntPtr CopyBufferToParameters; public IntPtr ReleaseTrainingSession; public IntPtr ReleaseCheckpointState; + public IntPtr ExportModelForInferencing; } internal static class NativeTrainingMethods diff --git a/onnxruntime/test/testdata/training_api/eval_model.onnx b/onnxruntime/test/testdata/training_api/eval_model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d607c282f4dfbe32593f40dcdf4084356f1c080d GIT binary patch literal 932 zcmbVL-A6;8RlG3KY|JJmCcAgs5U5j20|ObR{@nQp z`UZVFU%&@&0HsZXUU)N^Ip6sX-#OnwuA*vm&q6vFwIsc(?{_uxSpo)Ck3lD{b~s@?)7Dx_9&3vU(rLhw<3}S$EtMWlrh?F^gXm&5 zd)Gi>)FJWd4`C{l;TQ_aqSYS7{-pPhAG2@4EZ7zk2yUP#F_CY~$u~0jawd8=6K$l8 zd{Ok10C_c#e8U-$gj%etbX8+L8`>pLP{68GBz_~l4 zobiV+8bir#u?U!d@zZ_-20eS^s+TCRfc; zPwJgW-NBy+@RI;rnrScGS6zVFq$k}_?efY6k^PS~bxueJe>>;YFcfRy;!364EQfQg zB9pYM@JQ9|D|NR}`0O^qA}hQ^bYsJE{j-0_v^I+lI+wYRx1_m4ko z?PbW@btpS@x|q0{Wwwe2Rn2>06aEq4NnkESpjwp=Y;r zuFNb);iS{oiWfLKB~LR&%_b-Eb-rB~%5ke44%+)rH!aInZZ!8wp@!|=DJt6alfDY# zJ4)@TX804FWAGhh10WTa&5V%#dGIUWOsVN=XzrvX2LHVN`;|zg@rG0*-uSk5Ip4in z&|8uwVzNiVh!B0wgN z)Ivn|Fn(MG^-(ut{&>`3^=!!#G5F`~^XdP$09pDyA}kFr)$>G%OF1H>i##I4MPcqM z@_Qyi+)4}^62ZTjA|k|B2FZ`S%p}6^IWY@vLq>z(HjHvCZpouTMvn1G6uS@Cp}`*l ze9W1b)X8LFEOeC>cxk+t&8P!BMyTrQ4|iD4=G65-lcQWUSmT>@mKqvXkF&sl<@8$h zH|96Gid8e3{8r|3599bOaT=*0lLxj6@i0e47C?DDfPNf0mW= znYiq^ER6RdrDrZd5o1yqOYmSp!}Exgy@Usg#^yMNRhM zDZ01IWn|Y&o$Fre<0yB_X?BZ}Z%a<^U~-M+9oY2?t~n=7XrQb06<0CauTL@_B_U>> zkBlS3joS8b2a>OXnlStb8mcC-XNzPz^m~ooNcaNPOwGJR=^i{zJR>A?a&EV&1Fufj&d8 GZT<(yvB$6g delta 1070 zcmaJ=OK%cU6y^%d47a=-Q)-8zKs5$*rgun-fQcBIT1{BBF;O=Y9#aYdh75zH8WVqk zNj*t->e6KwZklM~O5>`oNp$1F#Kgq-51a=GBk1a$dmi6+&v(v!e*VYOS%zlm&c3Fo zm0gY!1yK}8?pYXj4;G_-cfH-|X&kjFw^|Z`C3JIj8Cy=)tF+twT+R%oIciI8_9QXx zGEV#m+oA`*Vzh;kGB9pEz>Lz#82#Xlts{*)QrQd_ZLTJ;J_GtyEJ>Rcr(JL*N@xnn~RC|KV0o^%mdg zxV923n!Tpr9iC5HrnIT^tyI?~Rj&3ba+M9u_Hzrv(%a7ap1XLM1xL;(J+cSRWae;+ z(k3Gt*r}w1J-J@n)vPAiA?${$gbv}AE5eA+SY_1tD-ujRYa^l}eJQc0SsC7EcTz#2G>mDAF}3 zc*1Z(V~Be00MDGjYx*JdJRh$G%^l|(b&Z$%9YxiW3qcfjIUG;&1^Wgia|{a7Ipio< zBR4y|WNJOt$=|u=dW@F+3m!nsn}MHmDRdQ%=!@`q?hIOi7oKH!?7b3V@p$bi>|zk> zC<_?!o)~=h-q2qjhHu_^xaUjw5|hfqC*L{P^Yekr_Q4r`=ECrY8W$sB@;kv>e& user_inputs, std::vector& user_outputs) -> void { ORT_THROW_IF_ERROR(model->EvalStep(user_inputs, user_outputs)); }) - .def("reset_grad", [](onnxruntime::training::api::Module* model) -> void { - ORT_THROW_IF_ERROR(model->ResetGrad()); - }) - .def("copy_parameters_to_buffer", [](onnxruntime::training::api::Module* model, OrtValue& output) -> void { - ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output)); - }) - .def("copy_buffer_to_parameters", [](onnxruntime::training::api::Module* model, OrtValue& input) -> void { - ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input)); - }) - .def("get_parameters_size", [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t { - return model->GetParametersSize(trainable_only); - }) - .def("save_checkpoint", [](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void { - onnxruntime::training::api::CheckpointState state; - ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state)); - ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(state, - ToPathString(checkpoint_path))); - }); + .def("reset_grad", + [](onnxruntime::training::api::Module* model) -> void { + ORT_THROW_IF_ERROR(model->ResetGrad()); + }) + .def("copy_parameters_to_buffer", + [](onnxruntime::training::api::Module* model, OrtValue& output) -> void { + ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output)); + }) + .def("copy_buffer_to_parameters", + [](onnxruntime::training::api::Module* model, OrtValue& input) -> void { + ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input)); + }) + .def("get_parameters_size", + [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t { + return model->GetParametersSize(trainable_only); + }) + .def("save_checkpoint", + [](onnxruntime::training::api::Module* model, const std::string& checkpoint_path) -> void { + onnxruntime::training::api::CheckpointState state; + ORT_THROW_IF_ERROR(model->GetStateDict(state.module_checkpoint_state)); + ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(state, + ToPathString(checkpoint_path))); + }) + .def("export_model_for_inferencing", + [](onnxruntime::training::api::Module* model, const std::string& inference_model_path, + const std::vector& graph_output_names) -> void { + ORT_ENFORCE(model, "Received a nullptr for expected pointer to class training::api::Module"); + ORT_THROW_IF_ERROR(model->ExportModelForInferencing(inference_model_path, + graph_output_names)); + }); py::class_ checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc"); diff --git a/orttraining/orttraining/python/training/api/module.py b/orttraining/orttraining/python/training/api/module.py index 037c533c98..2bcca528cb 100644 --- a/orttraining/orttraining/python/training/api/module.py +++ b/orttraining/orttraining/python/training/api/module.py @@ -112,3 +112,9 @@ class Module: Copies buffer to parameters. """ self._model.copy_buffer_to_parameters(buffer) + + def export_model_for_inferencing(self, inference_model_uri: str, graph_output_names: list[str]) -> None: + """ + Exports the model for inferencing. + """ + self._model.export_model_for_inferencing(inference_model_uri, graph_output_names) diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 4f26fdb31c..278f8f59b6 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -21,7 +21,7 @@ class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingModel): def _create_training_models(): # Given - device = "cuda" + device = "cpu" batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 pt_model, onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size) @@ -82,11 +82,11 @@ def test_train_step(): fetches = model(forward_inputs) # Calculate loss using pytorch model to compare it with Module's output. - pt_outputs = pt_model(torch.from_numpy(inputs).to("cuda")) + pt_outputs = pt_model(torch.from_numpy(inputs)) loss_fn = torch.nn.CrossEntropyLoss() - pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).to("cuda").long()) + pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long()) - assert fetches[0] == pt_loss.item() + assert np.allclose(fetches[0], pt_loss.detach().numpy()) def test_eval_step(): @@ -211,3 +211,25 @@ def test_copy_buffer_to_parameters(): # Make sure the saved parameters are the same as the old parameters. assert np.array_equal(old_output_params.numpy(), saved_params.numpy()) + + +def test_export_model_for_inferencing(): + # Initialize Models + simple_model, onnx_model, _, eval_model, _ = _create_training_models() + + with tempfile.TemporaryDirectory() as temp_dir: + # Save models & checkpoint files to load them later. + checkpoint_file_path, model_file_path, eval_model_file_path = _get_test_models_path( + temp_dir, simple_model, onnx_model, eval_model=eval_model + ) + + # Create Checkpoint State. + state = CheckpointState(checkpoint_file_path) + + # Create a Module. + model = Module(model_file_path, state, eval_model_file_path) + + # Export inference model + inference_model_file_path = os.path.join(temp_dir, "inference_model.onnx") + model.export_model_for_inferencing(inference_model_file_path, ["output-0"]) + assert os.path.exists(inference_model_file_path) diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index da55071228..1c8a9decd6 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -17,6 +17,7 @@ #include "orttraining/training_api/include/checkpoint.h" #include "orttraining/training_api/include/lr_scheduler.h" #include "orttraining/test/training_api/core/data_utils.h" +#include "test/util/include/temp_dir.h" #include "default_providers.h" using json = nlohmann::json; @@ -47,6 +48,153 @@ void GenerateRandomInput(gsl::span dims, OrtValue& input) { onnxruntime::training::api::utils::CreateInputOrtValue(dims, data, &input); } +void TestModuleExport(const std::vector>& providers) { + auto training_model_uri = MODEL_FOLDER "training_model.onnx"; + auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; + + onnxruntime::training::api::CheckpointState state; + auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt"; + ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(checkpoint_to_load_path, state)); + + std::unique_ptr env; + ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model = std::make_unique( + ToUTF8String(training_model_uri), state.module_checkpoint_state.named_parameters, onnxruntime::SessionOptions(), + *env, providers, ToUTF8String(eval_model_uri)); + + auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir"); + if (Env::Default().FolderExists(test_dir)) { + ORT_ENFORCE(Env::Default().DeleteFolder(test_dir).IsOK()); + } + onnxruntime::test::TemporaryDirectory tmp_dir{test_dir}; + PathString inference_model_path{ + ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("inference_model.onnx"))}; + + std::vector graph_output_names({"output-0"}); + ASSERT_STATUS_OK(model->ExportModelForInferencing(ToUTF8String(inference_model_path), graph_output_names)); + + // Load model + ONNX_NAMESPACE::ModelProto eval_model; + ONNX_NAMESPACE::ModelProto inference_model; + ORT_THROW_IF_ERROR(Model::Load(eval_model_uri, eval_model)); + ORT_THROW_IF_ERROR(Model::Load(inference_model_path, inference_model)); + + // Check it has only one graph input + ASSERT_EQ(eval_model.graph().input().size(), 6); + ASSERT_EQ(inference_model.graph().input().size(), 1); + ASSERT_EQ(inference_model.graph().input()[0].name(), "input-0"); + + // Check that it does not have any node which has op type SoftmaxCrossEntropyLoss + auto softmaxceloss_node_found = [](auto& model) -> bool { + for (auto& node : model.graph().node()) { + if (node.op_type() == "SoftmaxCrossEntropyLoss") { + return true; + } + } + return false; + }; + ASSERT_EQ(softmaxceloss_node_found(eval_model), true); + ASSERT_EQ(softmaxceloss_node_found(inference_model), false); + + // Try running an inference session + auto inference_session = std::make_unique(onnxruntime::SessionOptions(), *env); + ASSERT_STATUS_OK(inference_session->Load(inference_model_path)); + ASSERT_STATUS_OK(inference_session->Initialize()); + std::vector input_names({"input-0"}); + OrtValue graph_input; + GenerateRandomInput(std::array{2, 784}, graph_input); + std::vector feeds; + feeds.emplace_back(graph_input); + std::vector output_names({"output-0"}); + std::vector outputs; + ASSERT_STATUS_OK(inference_session->Run(RunOptions(), input_names, feeds, output_names, &outputs)); + ASSERT_EQ(outputs.size(), 1U); +} + +void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) { + ASSERT_NEAR(expected, output, atol); + ASSERT_NEAR(expected, output, rtol * std::abs(expected)); +} + +#if defined(USE_CUDA) || defined(USE_ROCM) + +const int64_t total_step_count = 100; +const float initial_lr = 1e-3f; +const int64_t resume_step = total_step_count / 2; + +void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count, + int64_t warmup_step_count) { + /// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances. + auto model_uri = MODEL_FOLDER "training_model.onnx"; + auto optim_uri = MODEL_FOLDER "adamw.onnx"; + + onnxruntime::training::api::CheckpointState state; + auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt"; + ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_to_load_path, state)); + + onnxruntime::SessionOptions session_option; + std::unique_ptr env; + ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + const std::vector> providers{onnxruntime::test::DefaultCudaExecutionProvider()}; + auto model = std::make_unique( + ToUTF8String(model_uri), state.module_checkpoint_state.named_parameters, + session_option, *env, providers); + auto optim = std::make_shared( + ToUTF8String(optim_uri), model->NamedParameters(), session_option, + *env, providers); + + OrtValue input, target; + GenerateRandomInput(std::array{2, 784}, input); + onnxruntime::training::api::utils::CreateInputOrtValue( + std::array{2}, std::vector(2, 1), &target); + + /// Load test data for learning rate schedulers. + auto data_uri = ORT_TSTR("testdata/test_data_generation/lr_scheduler/" + test_file_name); + std::ifstream in{data_uri}; + // Element of vector represent a pair of > + typedef std::vector>> TestDataDictType; + TestDataDictType test_data; + const json j = json::parse(in); + j.get_to(test_data); + + int64_t resume_step = (*test_data.begin()).first; + ASSERT_EQ(total_step_count, static_cast(test_data.size()) + resume_step); + + if (resume_step != 0) { + /// Reset optimizer states to match the initial state we want to test. + onnxruntime::training::api::OptimizerCheckpointState optimizer_checkpoint_states; + auto group_opt_state = + optimizer_checkpoint_states.group_named_optimizer_states["group0"] = + std::make_shared(); + group_opt_state->step = resume_step; + group_opt_state->initial_lr = initial_lr; + ASSERT_STATUS_OK(optim->LoadStateDict(optimizer_checkpoint_states)); + } + + // KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate. + // If we restored it after creation, it will only affect the learning rate from the second step. + auto scheduler = std::make_unique( + optim, warmup_step_count, total_step_count); + + for (auto it = test_data.begin(); it != test_data.end(); ++it) { + onnxruntime::training::api::OptimizerCheckpointState optimizer_states; + ASSERT_STATUS_OK(optim->GetStateDict(optimizer_states)); + auto group_optimizer_state = optimizer_states.group_named_optimizer_states["group0"]; + CompareValue(it->second[0], group_optimizer_state->learning_rate); + ASSERT_EQ(it->first, group_optimizer_state->step); + + std::vector inputs{input, target}; + std::vector fetches; + ASSERT_STATUS_OK(model->TrainStep(inputs, fetches)); + ASSERT_STATUS_OK(optim->Step()); + ASSERT_STATUS_OK(scheduler->Step()); + } +} + +#endif + +} // namespace + TEST(TrainingApiTest, ModuleParametersSize) { auto model_uri = MODEL_FOLDER "training_model.onnx"; @@ -164,8 +312,18 @@ TEST(TrainingApiTest, ModuleTrainStep) { } } +TEST(TrainingApiTest, ModuleExportModelForInferencingCPU) { + std::vector> providers{onnxruntime::test::DefaultCpuExecutionProvider()}; + TestModuleExport(providers); +} + #if defined(USE_CUDA) || defined(USE_ROCM) +TEST(TrainingApiTest, ModuleExportModelForInferencingCUDA) { + std::vector> providers{onnxruntime::test::DefaultCudaExecutionProvider()}; + TestModuleExport(providers); +} + TEST(TrainingApiTest, OptimStep) { auto model_uri = MODEL_FOLDER "training_model.onnx"; auto optim_uri = MODEL_FOLDER "adamw.onnx"; @@ -241,83 +399,6 @@ TEST(TrainingApiTest, OptimStep) { } } -void CompareValue(float expected, float output, float rtol = 1e-4, float atol = 1e-5) { - ASSERT_NEAR(expected, output, atol); - ASSERT_NEAR(expected, output, rtol * std::abs(expected)); -} - -void TestLRSchduler(const std::string& test_file_name, float initial_lr, int64_t total_step_count, - int64_t warmup_step_count) { - /// Load model and optimizer graph, create Module, Optimizer and LRScheduler instances. - auto model_uri = MODEL_FOLDER "training_model.onnx"; - auto optim_uri = MODEL_FOLDER "adamw.onnx"; - - onnxruntime::training::api::CheckpointState state; - auto checkpoint_to_load_path = MODEL_FOLDER "checkpoint.ckpt"; - ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_to_load_path, state)); - - onnxruntime::SessionOptions session_option; - std::unique_ptr env; - ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - const std::vector> providers{onnxruntime::test::DefaultCudaExecutionProvider()}; - auto model = std::make_unique( - ToUTF8String(model_uri), state.module_checkpoint_state.named_parameters, - session_option, *env, providers); - auto optim = std::make_shared( - ToUTF8String(optim_uri), model->NamedParameters(), session_option, - *env, providers); - - OrtValue input, target; - GenerateRandomInput(std::array{2, 784}, input); - onnxruntime::training::api::utils::CreateInputOrtValue( - std::array{2}, std::vector(2, 1), &target); - - /// Load test data for learning rate schedulers. - auto data_uri = ORT_TSTR("testdata/test_data_generation/lr_scheduler/" + test_file_name); - std::ifstream in{data_uri}; - // Element of vector represent a pair of > - typedef std::vector>> TestDataDictType; - TestDataDictType test_data; - const json j = json::parse(in); - j.get_to(test_data); - - int64_t resume_step = (*test_data.begin()).first; - ASSERT_EQ(total_step_count, static_cast(test_data.size()) + resume_step); - - if (resume_step != 0) { - /// Reset optimizer states to match the initial state we want to test. - onnxruntime::training::api::OptimizerCheckpointState optimizer_checkpoint_states; - auto group_opt_state = - optimizer_checkpoint_states.group_named_optimizer_states["group0"] = - std::make_shared(); - group_opt_state->step = resume_step; - group_opt_state->initial_lr = initial_lr; - ASSERT_STATUS_OK(optim->LoadStateDict(optimizer_checkpoint_states)); - } - - // KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate. - // If we restored it after creation, it will only affect the learning rate from the second step. - auto scheduler = std::make_unique( - optim, warmup_step_count, total_step_count); - - for (auto it = test_data.begin(); it != test_data.end(); ++it) { - onnxruntime::training::api::OptimizerCheckpointState optimizer_states; - ASSERT_STATUS_OK(optim->GetStateDict(optimizer_states)); - auto group_optimizer_state = optimizer_states.group_named_optimizer_states["group0"]; - CompareValue(it->second[0], group_optimizer_state->learning_rate); - ASSERT_EQ(it->first, group_optimizer_state->step); - - std::vector inputs{input, target}; - std::vector fetches; - ASSERT_STATUS_OK(model->TrainStep(inputs, fetches)); - ASSERT_STATUS_OK(optim->Step()); - ASSERT_STATUS_OK(scheduler->Step()); - } -} - -const int64_t total_step_count = 100; -const float initial_lr = 1e-3f; -const int64_t resume_step = total_step_count / 2; TEST(TrainingApiTest, LinearLRScheduler_NoWarmUp_Test) { // No warm up. TestLRSchduler("warmup_linear_scheduler_warmupstep-0.json", initial_lr, total_step_count, 0); @@ -360,7 +441,6 @@ TEST(TrainingApiTest, LinearLRScheduler_WarmUp200Step_ResumeFromCheckpoint_Test) #endif -} // namespace } // namespace test } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index 028f75dc3d..7f4494cde9 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -6,7 +6,6 @@ #include "core/common/logging/sinks/clog_sink.h" #include "core/common/path.h" #include "core/framework/framework_common.h" -#include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/platform/env.h" @@ -16,6 +15,7 @@ #include "orttraining/core/framework/checkpoint_common.h" #include "orttraining/core/framework/protobuf_message_sequence.h" #include "orttraining/training_api/include/checkpoint.h" +#include "orttraining/training_api/include/utils.h" namespace onnxruntime { namespace training { @@ -53,10 +53,6 @@ Status CreateTensorProtosFromOrtValues( [](const NameMLValMap::value_type& v) { return v.first; }); std::sort(ordered_tensor_names.begin(), ordered_tensor_names.end()); - // Copy the tensor data and create TensorProto storing the data. - InlinedVector tensor_data_buffer{}; - static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator}; - saved_tensor_protos.reserve(ordered_tensor_names.size()); uint64_t total_bytes = 0; @@ -65,7 +61,6 @@ Status CreateTensorProtosFromOrtValues( const OrtValue& ort_value = name_to_ort_value.at(tensor_name); ORT_RETURN_IF_NOT(ort_value.IsTensor(), "ort_value.IsTensor() was false"); const Tensor& src_tensor = ort_value.Get(); - tensor_data_buffer.resize(src_tensor.SizeInBytes()); // Currently large model size not considered, so exception thrown here // when protobuf upper limit hit. @@ -74,23 +69,8 @@ Status CreateTensorProtosFromOrtValues( ORT_THROW("checkpoint file size hit upper limit."); } - auto& tensor_location = src_tensor.Location(); - if (tensor_location.device.Type() == OrtDevice::CPU || - tensor_location.mem_type == OrtMemTypeCPUInput || - tensor_location.mem_type == OrtMemTypeCPUOutput || - tensor_location.device.Type() == OrtDevice::GPU) { - gsl::span dst_span = gsl::make_span(tensor_data_buffer); - ORT_RETURN_IF_NOT(src_tensor.SizeInBytes() == static_cast(dst_span.size_bytes()), "src size != dst size"); - Tensor dst_tensor{src_tensor.DataType(), src_tensor.Shape(), dst_span.data(), cpu_alloc_info}; - ORT_RETURN_IF_ERROR(data_transfer_manager.CopyTensor(src_tensor, dst_tensor)); - - // Convert Tensor to TensorProto. - ONNX_NAMESPACE::TensorProto tensor_proto; - tensor_proto = utils::TensorToTensorProto(dst_tensor, tensor_name); - saved_tensor_protos.emplace_back(tensor_proto); - } else { - ORT_THROW("Unsupported device type for saving tensors"); - } + saved_tensor_protos.emplace_back(utils::CopyTensorToTensorProto( + src_tensor, tensor_name, data_transfer_manager)); } return Status::OK(); diff --git a/orttraining/orttraining/training_api/include/module.h b/orttraining/orttraining/training_api/include/module.h index cc060f5252..738530f54e 100644 --- a/orttraining/orttraining/training_api/include/module.h +++ b/orttraining/orttraining/training_api/include/module.h @@ -106,6 +106,11 @@ struct Module { // Copy parameter values from contiguous buffer held by parameters_buffer onto parameters Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true); + // Load the eval model from eval_model_path_or_bytes and transform it for the purpose of + // inferencing, and serialize to given path + Status ExportModelForInferencing(const std::string& inference_model_path, + gsl::span graph_output_names) const; + private: std::unique_ptr train_sess_{nullptr}; std::unique_ptr eval_sess_{nullptr}; @@ -118,6 +123,7 @@ struct Module { std::vector gradients_; bool accumulate_gradient_ = false; const std::unordered_map>& named_parameters_; + std::string eval_model_path_; }; } // namespace api diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index b8e363359f..2e14d082e9 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -296,6 +296,25 @@ struct OrtTrainingApi { * */ ORT_CLASS_RELEASE(CheckpointState); + + /** \brief Export a model that can be used for inferencing. + * + * If the training session was provided with an eval model, the training session can generate + * an inference model if it knows the inference graph outputs. The input inference graph outputs + * are used to prune the eval model so that the output model's outputs align with the provided outputs. + * The exported model is saved at the path provided and can be used for inferencing with InferenceSession. + * Note that the function re-loads the eval model from the path provided to CreateTrainingSession and expects + * that this path still be valid. + * + * \param[in] sess The training session. + * \param[in] inference_model_path Path where the inference model should be serialized to. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess, + _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len, + _In_reads_(graph_outputs_len) const char* const* graph_output_names); }; typedef struct OrtTrainingApi OrtTrainingApi; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 1d812ba188..d810085646 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -143,6 +143,14 @@ class TrainingSession : public detail::Base { * */ void OptimizerStep(); + + /** \brief Exports a model that can be used for inferencing with the inference session. + * + * Wraps OrtTrainingApi::ExportModelForInferencing + * + */ + void ExportModelForInferencing(const std::basic_string& inference_model_path, + const std::vector& graph_output_names); }; } // namespace Ort diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 3b04765336..0ebff687f4 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -94,4 +94,14 @@ inline void CheckpointState::SaveCheckpoint(const TrainingSession& session, ThrowOnError(GetTrainingApi().SaveCheckpoint(path_to_checkpoint.c_str(), session, include_optimizer_states)); } +inline void TrainingSession::ExportModelForInferencing(const std::basic_string& inference_model_path, + const std::vector& graph_output_names) { + std::vector output_names(graph_output_names.size(), nullptr); + for (auto& output_name : graph_output_names) { + output_names.push_back(output_name.c_str()); + } + ThrowOnError(GetTrainingApi().ExportModelForInferencing( + p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data())); +} + } // namespace Ort diff --git a/orttraining/orttraining/training_api/include/ort_training_apis.h b/orttraining/orttraining/training_api/include/ort_training_apis.h index 5521f922bc..833989a1a8 100644 --- a/orttraining/orttraining/training_api/include/ort_training_apis.h +++ b/orttraining/orttraining/training_api/include/ort_training_apis.h @@ -60,4 +60,8 @@ ORT_API(void, ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckpointState* sessio ORT_API(void, ReleaseTrainingSession, _Frees_ptr_opt_ OrtTrainingSession* session); +ORT_API_STATUS_IMPL(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess, + _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len, + _In_reads_(graph_outputs_len) const char* const* graph_output_names); + } // namespace OrtTrainingApis diff --git a/orttraining/orttraining/training_api/include/training_session.h b/orttraining/orttraining/training_api/include/training_session.h index 6681c9e05c..379ba31a15 100644 --- a/orttraining/orttraining/training_api/include/training_session.h +++ b/orttraining/orttraining/training_api/include/training_session.h @@ -63,11 +63,14 @@ class TrainingSession { Status CreateCheckpointState(CheckpointState& chkpt_state, bool save_optimizer_state) const; - size_t GetParametersSize(const bool trainable_only=true) const; + size_t GetParametersSize(const bool trainable_only = true) const; - Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only=true); - - Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only=true); + Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true); + + Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true); + + Status ExportModelForInferencing(const std::string& inference_model_path, + gsl::span graph_output_names) const; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TrainingSession); diff --git a/orttraining/orttraining/training_api/include/utils.h b/orttraining/orttraining/training_api/include/utils.h index 9d00b05a00..f1275e3101 100644 --- a/orttraining/orttraining/training_api/include/utils.h +++ b/orttraining/orttraining/training_api/include/utils.h @@ -85,6 +85,9 @@ T GetValue(OrtValue& ort_value) { return val; } +ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& tensor, const std::string& tensor_proto_name, + const DataTransferManager& data_transfer_manager); + } // namespace utils } // namespace api } // namespace training diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 0a29da2b74..cbf90e5472 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/common/safeint.h" +#include "core/common/string_utils.h" #include "core/framework/execution_provider.h" #include "core/session/inference_session.h" #include "core/session/environment.h" @@ -20,6 +21,94 @@ namespace { // TODO: consolidate with frontend tooling const std::string ACCUMULATE_GRAD_CONTROL_INPUT_NAME{"lazy_reset_grad"}; +std::unordered_set GetReverseReachableNodes(Graph& inference_graph, + InlinedVector& output_node_args) { + // Perform a graph traversal from the graph outputs to collect all reachable nodes from the outputs + InlinedVector nodes; + nodes.reserve((output_node_args.size())); + std::unordered_set visited_nodes; + for (auto node_arg : output_node_args) { + auto* node = inference_graph.GetProducerNode(node_arg->Name()); + if (!node || std::find(nodes.begin(), nodes.end(), node->Index()) != nodes.end()) { + continue; + } + + nodes.push_back(node->Index()); + } + + inference_graph.ReverseDFSFrom(nodes, [&visited_nodes](const Node* node) { visited_nodes.insert(node); }, {}); + + return visited_nodes; +} + +Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector& output_node_args) { + auto reachable_nodes = GetReverseReachableNodes(inference_graph, output_node_args); + + // Get all graph nodes and remove those that are not in the reachable nodes. + GraphViewer graph_viewer(inference_graph); + for (auto& node : graph_viewer.Nodes()) { + if (!reachable_nodes.count(&node)) { + inference_graph.RemoveNode(node.Index()); + } + } + + return Status::OK(); +} + +Status TransformModelOutputsForInference(Graph& inference_graph, + gsl::span inference_graph_outputs) { + // Model is updated to remove any outputs that are not defined in inference_graph_outputs. Nodes + // producing these unused model outputs are also subsequently removed. + + ORT_RETURN_IF(inference_graph_outputs.empty(), + "Expected a non empty vector of graph output names. Got an empty vector."); + + InlinedVector inference_graph_output_node_args; + inference_graph_output_node_args.reserve(inference_graph_outputs.size()); + for (const auto& output_name : inference_graph_outputs) { + const NodeArg* output_node_arg = inference_graph.GetNodeArg(std::string(output_name)); + ORT_RETURN_IF_NOT(output_node_arg, "Expected graph output for inference graph " + std::string(output_name) + + " could not be found. Please regenerate the eval graph."); + inference_graph_output_node_args.push_back(output_node_arg); + } + + // Set the inference graph outputs, and remove any unused nodes. + inference_graph.SetOutputs(inference_graph_output_node_args); + ORT_RETURN_IF_ERROR(RemoveUnusedNodes(inference_graph, inference_graph_output_node_args)); + + ORT_RETURN_IF_ERROR(inference_graph.Resolve()); + + return Status::OK(); +} + +Status TransformModelInputsForInference(Graph& inference_graph, + const std::unordered_map< + std::string, std::shared_ptr>& named_parameters, + const DataTransferManager& data_transfer_manager) { + std::vector user_graph_inputs; + for (auto& graph_input_node_arg : inference_graph.GetInputs()) { + auto named_parameter_it = named_parameters.find(graph_input_node_arg->Name()); + if (named_parameter_it == named_parameters.end()) { + if (inference_graph.GetConsumerNodes(graph_input_node_arg->Name()).empty()) { + continue; + } + user_graph_inputs.emplace_back(graph_input_node_arg); + } else { + ORT_ENFORCE(!inference_graph.IsInitializedTensor(named_parameter_it->first), + "The eval graph is invalid. Expected model parameter ", + named_parameter_it->first, " to be a graph input, not a graph initializer."); + inference_graph.AddInitializedTensor(utils::CopyTensorToTensorProto( + named_parameter_it->second->Data().Get(), + named_parameter_it->first, data_transfer_manager)); + } + } + + inference_graph.SetInputs(user_graph_inputs); + ORT_RETURN_IF_ERROR(inference_graph.Resolve()); + + return Status::OK(); +} + } // namespace Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) { @@ -162,6 +251,9 @@ Module::Module(const std::string& train_model_path_or_bytes, if (eval_model_path_or_bytes.has_value()) { eval_sess_ = std::make_unique(session_options, env); ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value())); + for (const auto& provider : providers) { + ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); + } ORT_THROW_IF_ERROR(eval_sess_->Initialize()); utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); @@ -185,6 +277,10 @@ Module::Module(const std::string& train_model_path_or_bytes, } eval_input_names_ = eval_user_input_names; eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + + // Keep a copy of the eval model path to be able to later export the model for inferencing. + // The inference model will be reconstructed from the eval model. + eval_model_path_ = eval_model_path_or_bytes.value(); } } @@ -232,7 +328,7 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr ORT_ENFORCE(nullptr != init_tensor); auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, - "Parameters buffer size incorrect. Expected:",expected_buffer_size, + "Parameters buffer size incorrect. Expected:", expected_buffer_size, ", Actual:", init_tensor->Shape().Size()); const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); @@ -275,7 +371,7 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr ORT_ENFORCE(nullptr != init_tensor); auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, - "Parameters buffer size incorrect. Expected:",expected_buffer_size, + "Parameters buffer size incorrect. Expected:", expected_buffer_size, ", Actual:", init_tensor->Shape().Size()); const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); @@ -356,6 +452,33 @@ Status Module::GetStateDict(ModuleCheckpointState& module_checkpoint_state) { return Status::OK(); } +Status Module::ExportModelForInferencing(const std::string& inference_model_path, + gsl::span graph_output_names) const { + ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(), + "Eval model was not provided. Cannot export a model for inferencing."); + + ONNX_NAMESPACE::ModelProto eval_model; + ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model)); + + // Clone the eval mode into an inference onnxruntime::Model. + std::shared_ptr inference_model; + ORT_RETURN_IF_ERROR(Model::Load(eval_model, inference_model, nullptr, logging::LoggingManager::DefaultLogger())); + + // The cloned model's outputs are transformed such that the model has outputs as defined by graph_output_names + // Any nodes not contributing to the inference outputs will be pruned. + ORT_THROW_IF_ERROR(TransformModelOutputsForInference(inference_model->MainGraph(), graph_output_names)); + + // The cloned model's inputs are transformed such that the model has only user defined inputs. All parameters + // are moved to be constant initializers for the model. + ORT_RETURN_IF_ERROR(TransformModelInputsForInference(inference_model->MainGraph(), named_parameters_, + eval_sess_->GetDataTransferManager())); + + // Save the model at desired location. + ORT_THROW_IF_ERROR(Model::Save(*inference_model, inference_model_path)); + + return Status::OK(); +} + } // namespace api } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index cb439efda0..7d37335a4e 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -339,6 +339,31 @@ ORT_API(void, OrtTrainingApis::ReleaseCheckpointState, _Frees_ptr_opt_ OrtCheckp delete reinterpret_cast(checkpoint_state); } +ORT_API_STATUS_IMPL(OrtTrainingApis::ExportModelForInferencing, _Inout_ OrtTrainingSession* sess, + _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len, + _In_reads_(graph_outputs_len) const char* const* graph_output_names) { + API_IMPL_BEGIN + + auto session = reinterpret_cast(sess); + + onnxruntime::InlinedVector output_names(graph_outputs_len); + + for (size_t i = 0; i != graph_outputs_len; ++i) { + if (graph_output_names[i] == nullptr || graph_output_names[i][0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Name of graph output cannot be empty. Please provide valid graph names"); + } + + output_names[i] = graph_output_names[i]; + } + + ORT_API_RETURN_IF_STATUS_NOT_OK( + session->ExportModelForInferencing(onnxruntime::ToUTF8String(inference_model_path), output_names)); + + return nullptr; + API_IMPL_END +} + static constexpr OrtTrainingApi ort_training_api = { // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially // released, it is OK to change the order here, however a corresponding matching change should also be done in the @@ -363,6 +388,7 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::CopyBufferToParameters, &OrtTrainingApis::ReleaseTrainingSession, &OrtTrainingApis::ReleaseCheckpointState, + &OrtTrainingApis::ExportModelForInferencing, }; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 939ad7b7c5..654e94df34 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -110,6 +110,11 @@ Status TrainingSession::CopyBufferToParameters(OrtValue& parameters_buffer, cons return module_->CopyBufferToParameters(parameters_buffer, trainable_only); } +Status TrainingSession::ExportModelForInferencing(const std::string& inference_model_path, + gsl::span graph_output_names) const { + return module_->ExportModelForInferencing(inference_model_path, graph_output_names); +} + } // namespace api } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/training_api/utils.cc b/orttraining/orttraining/training_api/utils.cc index 7696bf67ca..a4c95a53b4 100644 --- a/orttraining/orttraining/training_api/utils.cc +++ b/orttraining/orttraining/training_api/utils.cc @@ -6,6 +6,7 @@ #include "core/framework/ort_value.h" #include "core/framework/tensor.h" #include "core/framework/allocator.h" +#include "core/framework/tensorprotoutils.h" #include "orttraining/training_api/include/utils.h" @@ -87,6 +88,31 @@ Status OrtValueLike(const SessionState& sess_state, const OrtValue& input_val, O return Status::OK(); } +ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& src_tensor, const std::string& tensor_proto_name, + const DataTransferManager& data_transfer_manager) { + auto& tensor_location = src_tensor.Location(); + if (tensor_location.device.Type() != OrtDevice::CPU && + tensor_location.mem_type != OrtMemTypeCPUInput && + tensor_location.mem_type != OrtMemTypeCPUOutput && + tensor_location.device.Type() != OrtDevice::GPU) { + ORT_THROW("Unsupported device type for saving tensors"); + } + + // Copy the tensor data and create TensorProto storing the data. + InlinedVector tensor_data_buffer{}; + tensor_data_buffer.resize(src_tensor.SizeInBytes()); + static const OrtMemoryInfo cpu_alloc_info{onnxruntime::CPU, OrtDeviceAllocator}; + + gsl::span dst_span = gsl::make_span(tensor_data_buffer); + ORT_ENFORCE(src_tensor.SizeInBytes() == static_cast(dst_span.size_bytes()), "src size != dst size"); + Tensor dst_tensor{src_tensor.DataType(), src_tensor.Shape(), dst_span.data(), cpu_alloc_info}; + ORT_THROW_IF_ERROR(data_transfer_manager.CopyTensor(src_tensor, dst_tensor)); + + // Convert Tensor to TensorProto. + ONNX_NAMESPACE::TensorProto tensor_proto; + return onnxruntime::utils::TensorToTensorProto(dst_tensor, tensor_proto_name); +} + } // namespace utils } // namespace api } // namespace training