mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Enable saving optimized models in OrtModule (#7214)
* Enable saving optimized models in OrtModule Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
ebde320950
commit
a98c2ebb8c
4 changed files with 29 additions and 6 deletions
|
|
@ -94,7 +94,14 @@ std::string ModuleGradientGraphBuilder::GetGradientModel() const {
|
|||
if (!gradient_model_->ToProto().SerializeToString(&model_str)) {
|
||||
ORT_THROW("Fail to serialize gradient model to string.");
|
||||
}
|
||||
return model_str;
|
||||
}
|
||||
|
||||
// Returns the inference optimized model serialized to a string.
// Converts inference_optimized_model_ to its proto form and serializes it into
// a std::string; throws via ORT_THROW if serialization reports failure.
// NOTE(review): inference_optimized_model_ is dereferenced without a null
// check. In the visible code it is populated by Model::Load inside
// BuildGradientGraph(), so this presumably must be called only after a
// successful BuildGradientGraph() - confirm against callers.
std::string ModuleGradientGraphBuilder::GetInferenceOptimizedModel() const {
|
||||
std::string model_str;
|
||||
if (!inference_optimized_model_->ToProto().SerializeToString(&model_str)) {
|
||||
ORT_THROW("Fail to serialize inference optimized model to string.");
|
||||
}
|
||||
return model_str;
|
||||
}
|
||||
|
||||
|
|
@ -159,6 +166,9 @@ Status ModuleGradientGraphBuilder::BuildGradientGraph() {
|
|||
graph_transformation_mgr.ApplyTransformers(gradient_graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
}
|
||||
|
||||
// Save a copy of inference optimized model
|
||||
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_->ToProto(), inference_optimized_model_, nullptr, *logger_));
|
||||
|
||||
// Build gradient graph.
|
||||
GradientGraphConfiguration gradient_graph_config{};
|
||||
gradient_graph_config.use_invertible_layernorm_grad = config_.use_invertible_layernorm_grad;
|
||||
|
|
|
|||
|
|
@ -76,6 +76,12 @@ class ModuleGradientGraphBuilder {
|
|||
*/
|
||||
std::string GetGradientModel() const;
|
||||
|
||||
/**
|
||||
* Get inference optimized model.
|
||||
* @return The inference optimized model serialized to string.
|
||||
*/
|
||||
std::string GetInferenceOptimizedModel() const;
|
||||
|
||||
/**
|
||||
* Get the training graphs information.
|
||||
* @return The training graphs information.
|
||||
|
|
@ -96,6 +102,7 @@ class ModuleGradientGraphBuilder {
|
|||
void ReorderOutputs();
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> model_;
|
||||
std::shared_ptr<onnxruntime::Model> inference_optimized_model_;
|
||||
std::shared_ptr<onnxruntime::Model> gradient_model_;
|
||||
TrainingGraphInfo training_graph_info_;
|
||||
|
||||
|
|
|
|||
|
|
@ -540,6 +540,10 @@ py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class use
|
|||
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return py::bytes(module_gradient_graph_builder->GetGradientModel());
|
||||
})
|
||||
.def("get_inference_optimized_model",
|
||||
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return py::bytes(module_gradient_graph_builder->GetInferenceOptimizedModel());
|
||||
})
|
||||
.def("get_training_graph_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return module_gradient_graph_builder->GetTrainingGraphInfo();
|
||||
});
|
||||
|
|
|
|||
|
|
@ -366,11 +366,8 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
def _get_inference_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs):
|
||||
self._onnx_inference = self._get_inference_graph(*inputs, **kwargs)
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_inference,
|
||||
self._save_onnx_prefix + '_inference.onnx')
|
||||
|
||||
onnx.save(self._onnx_inference, self._save_onnx_prefix + '_inference.onnx')
|
||||
self._initialize_module_gradient_graph_builder()
|
||||
|
||||
def _create_training_session(self):
|
||||
|
|
@ -397,6 +394,9 @@ class ORTModule(torch.nn.Module):
|
|||
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
|
||||
# 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2.
|
||||
session_options.log_severity_level = int(self._verbosity)
|
||||
# enable dumping optimized training graph
|
||||
if self._save_onnx:
|
||||
session_options.optimized_model_filepath = self._save_onnx_prefix + '_training_optimized.onnx'
|
||||
|
||||
self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
|
||||
session_options, providers, provider_options)
|
||||
|
|
@ -412,8 +412,10 @@ class ORTModule(torch.nn.Module):
|
|||
self._onnx_graphs_info = self._module_gradient_graph_builder.get_training_graph_info()
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training,
|
||||
self._save_onnx_prefix + '_training.onnx')
|
||||
inference_optimized_model = onnx.load_model_from_string(
|
||||
self._module_gradient_graph_builder.get_inference_optimized_model())
|
||||
onnx.save(inference_optimized_model, self._save_onnx_prefix + '_inference_optimized.onnx')
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_training.onnx')
|
||||
|
||||
def eval(self: T) -> T:
|
||||
self._flattened_output_module.eval()
|
||||
|
|
|
|||
Loading…
Reference in a new issue