mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Save optimized pre_grad graph once ready (#16816)
### Save optimized pre_grad graph once it's ready `graph_builder.build()` did two things for training: 1. optimized forward graph, e.g. pre_grad graph optimization. 2. build gradient graph. Originally after `graph_builder.build()` completed, pre_graph graph is saved. While if pre_grad graph optimization completed, but fail during gradient graph build, we still cannot get pre_grad graph to investigate. This PR made the change once pre_grad graph is ready, we save it (if save_model is enabled) in C++ backend.
This commit is contained in:
parent
ba49d64f67
commit
b9d80131a7
7 changed files with 22 additions and 14 deletions
|
|
@ -182,6 +182,10 @@ Status OrtModuleGraphBuilder::OptimizeForwardGraph(const TrainingGraphTransforme
|
|||
graph_transformation_mgr.ApplyTransformers(forward_graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
}
|
||||
|
||||
if (!config.optimized_pre_grad_filepath.empty()) {
|
||||
ORT_RETURN_IF_ERROR(Model::Save(*forward_model_, config.optimized_pre_grad_filepath));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@ struct TrainingGraphTransformerConfiguration : public GraphTransformerConfigurat
|
|||
|
||||
// Enable label sparsity compute optimization for the input names in the below list.
|
||||
std::vector<std::string> sparse_label_input_names;
|
||||
|
||||
// Path for serialization of the transformed optimized model. If empty, serialization is disabled.
|
||||
std::string optimized_pre_grad_filepath;
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -769,6 +769,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
|
|||
.def_readwrite("enable_compute_optimizer", &TrainingGraphTransformerConfiguration::enable_compute_optimizer)
|
||||
.def_readwrite("sparse_embedding_input_names", &TrainingGraphTransformerConfiguration::sparse_embedding_input_names)
|
||||
.def_readwrite("sparse_label_input_names", &TrainingGraphTransformerConfiguration::sparse_label_input_names)
|
||||
.def_readwrite("optimized_pre_grad_filepath", &TrainingGraphTransformerConfiguration::optimized_pre_grad_filepath)
|
||||
.def_readwrite("propagate_cast_ops_config", &TrainingGraphTransformerConfiguration::GraphTransformerConfiguration::propagate_cast_ops_config);
|
||||
|
||||
py::class_<OrtModuleGraphBuilderConfiguration> module_graph_builder_config(
|
||||
|
|
|
|||
|
|
@ -431,6 +431,15 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
graph_transformer_config.propagate_cast_ops_config.allow = self._runtime_options.propagate_cast_ops_allow
|
||||
graph_transformer_config.propagate_cast_ops_config.strategy = self._runtime_options.propagate_cast_ops_strategy
|
||||
graph_transformer_config.enable_compute_optimizer = self._runtime_options.enable_compute_optimizer
|
||||
|
||||
if self._debug_options.save_onnx_models.save:
|
||||
graph_transformer_config.optimized_pre_grad_filepath = os.path.join(
|
||||
self._debug_options.save_onnx_models.path,
|
||||
_onnx_models._get_onnx_file_name(
|
||||
self._debug_options.save_onnx_models.name_prefix, "optimized_pre_grad", self._export_mode
|
||||
),
|
||||
)
|
||||
|
||||
return graph_transformer_config
|
||||
|
||||
@_logger.TrackTime(_logger.ORTModuleInitPhase.GRAPH_BUILDER_INIT)
|
||||
|
|
|
|||
|
|
@ -26,15 +26,14 @@ class ONNXModels:
|
|||
1. exported_model: Model that is exported by torch.onnx.export
|
||||
2. optimized_model: For eval mode it's exported_model with concrete input shapes set if needed,
|
||||
for training mode, it's optimized model after gradients graph has been built.
|
||||
3. optimized_pre_grad_model: Optimized model before gradient graph is built. It's None for eval mode.
|
||||
4. In addition, ORTModule also saves the execution_model which is the model
|
||||
that is being executed by ORT. It has further optimizations done by the
|
||||
InferenceSession and is saved by the InferenceSession.
|
||||
In addition, ORTModule also saves two other models, to the user-provided path:
|
||||
a. the pre_grad_model which is the model before the gradients graph is built.
|
||||
b. the execution_model which is the model that is being executed by ORT.
|
||||
It has further optimizations done by the InferenceSession and is saved by the InferenceSession.
|
||||
"""
|
||||
|
||||
exported_model: Optional[onnx.ModelProto] = None
|
||||
optimized_model: Optional[onnx.ModelProto] = None
|
||||
optimized_pre_grad_model: Optional[onnx.ModelProto] = None
|
||||
|
||||
def save_exported_model(self, path, name_prefix, export_mode):
|
||||
# save the ortmodule exported model
|
||||
|
|
@ -47,8 +46,3 @@ class ONNXModels:
|
|||
_save_model(
|
||||
self.optimized_model, os.path.join(path, _get_onnx_file_name(name_prefix, "optimized", export_mode))
|
||||
)
|
||||
if self.optimized_pre_grad_model is not None:
|
||||
_save_model(
|
||||
self.optimized_pre_grad_model,
|
||||
os.path.join(path, _get_onnx_file_name(name_prefix, "optimized_pre_grad", export_mode)),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -360,9 +360,6 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
super()._build_graph(graph_transformer_config)
|
||||
self._onnx_models.optimized_model = onnx.load_model_from_string(self._graph_builder.get_gradient_model())
|
||||
self._onnx_models.optimized_pre_grad_model = onnx.load_model_from_string(
|
||||
self._graph_builder.get_forward_model()
|
||||
)
|
||||
|
||||
# Apply registered graph transformers to the optimized model
|
||||
device_type = self._device.type
|
||||
|
|
|
|||
|
|
@ -5232,7 +5232,7 @@ def test_sigmoid_grad_opset13():
|
|||
pt_prediction, pt_loss = run_step(pt_model, pt_x)
|
||||
if step == 0:
|
||||
model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models
|
||||
for name in ["exported_model", "optimized_model", "optimized_pre_grad_model"]:
|
||||
for name in ["exported_model", "optimized_model"]:
|
||||
onx = getattr(model_onx, name)
|
||||
opv = None
|
||||
for op in onx.opset_import:
|
||||
|
|
|
|||
Loading…
Reference in a new issue