Save optimized pre_grad graph once ready (#16816)

### Save optimized pre_grad graph once it's ready

`graph_builder.build()` did two things for training: 1. optimized
forward graph, e.g. pre_grad graph optimization. 2. build gradient
graph.

Originally after `graph_builder.build()` completed, pre_graph graph is
saved. While if pre_grad graph optimization completed, but fail during
gradient graph build, we still cannot get pre_grad graph to investigate.

This PR made the change once pre_grad graph is ready, we save it (if
save_model is enabled) in C++ backend.
This commit is contained in:
pengwa 2023-08-02 14:05:26 +08:00 committed by GitHub
parent ba49d64f67
commit b9d80131a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 22 additions and 14 deletions

View file

@ -182,6 +182,10 @@ Status OrtModuleGraphBuilder::OptimizeForwardGraph(const TrainingGraphTransforme
graph_transformation_mgr.ApplyTransformers(forward_graph, static_cast<TransformerLevel>(i), *logger_));
}
if (!config.optimized_pre_grad_filepath.empty()) {
ORT_RETURN_IF_ERROR(Model::Save(*forward_model_, config.optimized_pre_grad_filepath));
}
return Status::OK();
}

View file

@ -30,6 +30,9 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
// Enable label sparsity compute optimization for the input names in the below list.
std::vector<std::string> sparse_label_input_names;
// Path for serialization of the transformed optimized model. If empty, serialization is disabled.
std::string optimized_pre_grad_filepath;
};
} // namespace training

View file

@ -769,6 +769,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
.def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names)
.def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
.def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(

View file

@ -431,6 +431,15 @@ class GraphExecutionManager(GraphExecutionInterface):
graph_transformer_config.propagate_cast_ops_config.allow = self._runtime_options.propagate_cast_ops_allow
graph_transformer_config.propagate_cast_ops_config.strategy = self._runtime_options.propagate_cast_ops_strategy
graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer
if self._debug_options.save_onnx_models.save:
graph_transformer_config.optimized_pre_grad_filepath = os.path.join(
self._debug_options.save_onnx_models.path,
_onnx_models._get_onnx_file_name(
self._debug_options.save_onnx_models.name_prefix, "optimized_pre_grad", self._export_mode
),
)
return graph_transformer_config
@_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT)

View file

@ -26,15 +26,14 @@ class ONNXModels:
1. exported_model: Model that is exported by torch.onnx.export
2. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed,
for training mode, it's optimized model after gradients graph has been built.
3. optimized_pre_grad_model: Optimized model before gradient graph is built. It's None for eval mode.
4. In addition, ORTModule also saves the execution_model which is the model
that is being executed by ORT. It has further optimizations done by the
InferenceSession and is saved by the InferenceSession.
In addition, ORTModule also saves two other models, to the user-provided path:
a. the pre_grad_model which is the model before the gradients graph is built.
b. the execution_model which is the model that is being executed by ORT.
It has further optimizations done by the InferenceSession and is saved by the InferenceSession.
"""
exported_model: Optional[onnx.ModelProto] = None
optimized_model: Optional[onnx.ModelProto] = None
optimized_pre_grad_model: Optional[onnx.ModelProto] = None
def save_exported_model(self, path, name_prefix, export_mode):
# save the ortmodule exported model
@ -47,8 +46,3 @@ class ONNXModels:
_save_model(
self.optimized_model, os.path.join(path, _get_onnx_file_name(name_prefix, "optimized", export_mode))
)
if self.optimized_pre_grad_model is not None:
_save_model(
self.optimized_pre_grad_model,
os.path.join(path, _get_onnx_file_name(name_prefix, "optimized_pre_grad", export_mode)),
)

View file

@ -360,9 +360,6 @@ class TrainingManager(GraphExecutionManager):
super()._build_graph(graph_transformer_config)
self._onnx_models.optimized_model = onnx.load_model_from_string(self._graph_builder.get_gradient_model())
self._onnx_models.optimized_pre_grad_model = onnx.load_model_from_string(
self._graph_builder.get_forward_model()
)
# Apply registered graph transformers to the optimized model
device_type = self._device.type

View file

@ -5232,7 +5232,7 @@ def test_sigmoid_grad_opset13():
pt_prediction, pt_loss = run_step(pt_model, pt_x)
if step == 0:
model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models
for name in ["exported_model", "optimized_model", "optimized_pre_grad_model"]:
for name in ["exported_model", "optimized_model"]:
onx = getattr(model_onx, name)
opv = None
for op in onx.opset_import: