From b9d80131a72f3e3bcac86d138cc46c7b8a7e4b07 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 2 Aug 2023 14:05:26 +0800 Subject: [PATCH] Save optimized pre_grad graph once ready (#16816) ### Save optimized pre_grad graph once it's ready `graph_builder.build()` did two things for training: 1. optimized forward graph, e.g. pre_grad graph optimization. 2. build gradient graph. Originally after `graph_builder.build()` completed, pre_graph graph is saved. While if pre_grad graph optimization completed, but fail during gradient graph build, we still cannot get pre_grad graph to investigate. This PR made the change once pre_grad graph is ready, we save it (if save_model is enabled) in C++ backend. --- .../core/framework/ortmodule_graph_builder.cc | 4 ++++ .../core/optimizer/graph_transformer_config.h | 3 +++ .../orttraining/python/orttraining_pybind_state.cc | 1 + .../training/ortmodule/_graph_execution_manager.py | 9 +++++++++ .../python/training/ortmodule/_onnx_models.py | 14 ++++---------- .../python/training/ortmodule/_training_manager.py | 3 --- .../test/python/orttraining_test_ortmodule_api.py | 2 +- 7 files changed, 22 insertions(+), 14 deletions(-) diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index 4b7325135e..bfc6e4a5bb 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -182,6 +182,10 @@ Status OrtModuleGraphBuilder::OptimizeForwardGraph(const TrainingGraphTransforme graph_transformation_mgr.ApplyTransformers(forward_graph, static_cast(i), *logger_)); } + if (!config.optimized_pre_grad_filepath.empty()) { + ORT_RETURN_IF_ERROR(Model::Save(*forward_model_, config.optimized_pre_grad_filepath)); + } + return Status::OK(); } diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_config.h b/orttraining/orttraining/core/optimizer/graph_transformer_config.h index 7f82f944e0..cc3edfb016 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_config.h +++ b/orttraining/orttraining/core/optimizer/graph_transformer_config.h @@ -30,6 +30,9 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat // Enable label sparsity compute optimization for the input names in the below list. std::vector sparse_label_input_names; + + // Path for serialization of the transformed optimized model. If empty, serialization is disabled. + std::string optimized_pre_grad_filepath; }; } // namespace training diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 31a5819b81..a8d9db8d87 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -769,6 +769,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer) .def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names) .def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names) + .def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath) .def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config); py::class_ module_graph_builder_config( diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index eac8c2173b..abc44c9da2 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -431,6 +431,15 @@ class GraphExecutionManager(GraphExecutionInterface): graph_transformer_config.propagate_cast_ops_config.allow = self._runtime_options.propagate_cast_ops_allow graph_transformer_config.propagate_cast_ops_config.strategy = self._runtime_options.propagate_cast_ops_strategy graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer + + if self._debug_options.save_onnx_models.save: + graph_transformer_config.optimized_pre_grad_filepath = os.path.join( + self._debug_options.save_onnx_models.path, + _onnx_models._get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "optimized_pre_grad", self._export_mode + ), + ) + return graph_transformer_config @_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT) diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index 45e9992692..ac09c838af 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -26,15 +26,14 @@ class ONNXModels: 1. exported_model: Model that is exported by torch.onnx.export 2. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed, for training mode, it's optimized model after gradients graph has been built. - 3. optimized_pre_grad_model: Optimized model before gradient graph is built. It's None for eval mode. - 4. In addition, ORTModule also saves the execution_model which is the model - that is being executed by ORT. It has further optimizations done by the - InferenceSession and is saved by the InferenceSession. + In addition, ORTModule also saves two other models, to the user-provided path: + a. the pre_grad_model which is the model before the gradients graph is built. + b. the execution_model which is the model that is being executed by ORT. + It has further optimizations done by the InferenceSession and is saved by the InferenceSession. """ exported_model: Optional[onnx.ModelProto] = None optimized_model: Optional[onnx.ModelProto] = None - optimized_pre_grad_model: Optional[onnx.ModelProto] = None def save_exported_model(self, path, name_prefix, export_mode): # save the ortmodule exported model @@ -47,8 +46,3 @@ class ONNXModels: _save_model( self.optimized_model, os.path.join(path, _get_onnx_file_name(name_prefix, "optimized", export_mode)) ) - if self.optimized_pre_grad_model is not None: - _save_model( - self.optimized_pre_grad_model, - os.path.join(path, _get_onnx_file_name(name_prefix, "optimized_pre_grad", export_mode)), - ) diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index cb8561867a..b28526522b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -360,9 +360,6 @@ class TrainingManager(GraphExecutionManager): super()._build_graph(graph_transformer_config) self._onnx_models.optimized_model = onnx.load_model_from_string(self._graph_builder.get_gradient_model()) - self._onnx_models.optimized_pre_grad_model = onnx.load_model_from_string( - self._graph_builder.get_forward_model() - ) # Apply registered graph transformers to the optimized model device_type = self._device.type diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index bdb5dc4835..afbf1a23a3 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5232,7 +5232,7 @@ def test_sigmoid_grad_opset13(): pt_prediction, pt_loss = run_step(pt_model, pt_x) if step == 0: model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models - for name in ["exported_model", "optimized_model", "optimized_pre_grad_model"]: + for name in ["exported_model", "optimized_model"]: onx = getattr(model_onx, name) opv = None for op in onx.opset_import: