diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc index 99ca1bf2c9..03380ef294 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.cc @@ -94,7 +94,14 @@ std::string ModuleGradientGraphBuilder::GetGradientModel() const { if (!gradient_model_->ToProto().SerializeToString(&model_str)) { ORT_THROW("Fail to serialize gradient model to string."); } + return model_str; +} +std::string ModuleGradientGraphBuilder::GetInferenceOptimizedModel() const { + std::string model_str; + if (!inference_optimized_model_->ToProto().SerializeToString(&model_str)) { + ORT_THROW("Fail to serialize inference optimized model to string."); + } return model_str; } @@ -159,6 +166,9 @@ Status ModuleGradientGraphBuilder::BuildGradientGraph() { graph_transformation_mgr.ApplyTransformers(gradient_graph, static_cast(i), *logger_)); } + // Save a copy of inference optimized model + ORT_RETURN_IF_ERROR(Model::Load(gradient_model_->ToProto(), inference_optimized_model_, nullptr, *logger_)); + // Build gradient graph. GradientGraphConfiguration gradient_graph_config{}; gradient_graph_config.use_invertible_layernorm_grad = config_.use_invertible_layernorm_grad; diff --git a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h index 3379ab65ef..9410a89bbb 100644 --- a/orttraining/orttraining/core/framework/module_gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/module_gradient_graph_builder.h @@ -76,6 +76,12 @@ class ModuleGradientGraphBuilder { */ std::string GetGradientModel() const; + /** + * Get inference optimized model. + * @return The inference optimized model serialized to string. + */ + std::string GetInferenceOptimizedModel() const; + /** * Get the training graphs information. 
* @return The training graphs information. @@ -96,6 +102,7 @@ class ModuleGradientGraphBuilder { void ReorderOutputs(); std::shared_ptr model_; + std::shared_ptr inference_optimized_model_; std::shared_ptr gradient_model_; TrainingGraphInfo training_graph_info_; diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index c458a130aa..55df003902 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -540,6 +540,10 @@ py::class_(m, "TrainingAgent", R"pbdoc(This is the main class use [](ModuleGradientGraphBuilder* module_gradient_graph_builder) { return py::bytes(module_gradient_graph_builder->GetGradientModel()); }) + .def("get_inference_optimized_model", + [](ModuleGradientGraphBuilder* module_gradient_graph_builder) { + return py::bytes(module_gradient_graph_builder->GetInferenceOptimizedModel()); + }) .def("get_training_graph_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) { return module_gradient_graph_builder->GetTrainingGraphInfo(); }); diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index 29d4387b16..ea1cd743af 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -366,11 +366,8 @@ class ORTModule(torch.nn.Module): def _get_inference_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs): self._onnx_inference = self._get_inference_graph(*inputs, **kwargs) - if self._save_onnx: - onnx.save(self._onnx_inference, - self._save_onnx_prefix + '_inference.onnx') - + onnx.save(self._onnx_inference, self._save_onnx_prefix + '_inference.onnx') self._initialize_module_gradient_graph_builder() def _create_training_session(self): @@ -397,6 +394,9 @@ class ORTModule(torch.nn.Module): session_options.execution_order = 
onnxruntime.ExecutionOrder.PRIORITY_BASED # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. session_options.log_severity_level = int(self._verbosity) + # enable dumping optimized training graph + if self._save_onnx: + session_options.optimized_model_filepath = self._save_onnx_prefix + '_training_optimized.onnx' self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(), session_options, providers, provider_options) @@ -412,8 +412,10 @@ class ORTModule(torch.nn.Module): self._onnx_graphs_info = self._module_gradient_graph_builder.get_training_graph_info() if self._save_onnx: - onnx.save(self._onnx_training, - self._save_onnx_prefix + '_training.onnx') + inference_optimized_model = onnx.load_model_from_string( + self._module_gradient_graph_builder.get_inference_optimized_model()) + onnx.save(inference_optimized_model, self._save_onnx_prefix + '_inference_optimized.onnx') + onnx.save(self._onnx_training, self._save_onnx_prefix + '_training.onnx') def eval(self: T) -> T: self._flattened_output_module.eval()