Enable saving optimized models in OrtModule (#7214)

* Enable saving optimized models in OrtModule

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Sherlock 2021-04-02 12:37:05 -07:00 committed by GitHub
parent ebde320950
commit a98c2ebb8c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 6 deletions

View file

@ -94,7 +94,14 @@ std::string ModuleGradientGraphBuilder::GetGradientModel() const {
if (!gradient_model_->ToProto().SerializeToString(&model_str)) {
ORT_THROW("Fail to serialize gradient model to string.");
}
return model_str;
}
/**
 * Serialize the inference optimized model to a string.
 *
 * @return The protobuf-serialized bytes of inference_optimized_model_.
 * @throws via ORT_THROW if the inference optimized model has not been created yet,
 *         or if protobuf serialization fails.
 */
std::string ModuleGradientGraphBuilder::GetInferenceOptimizedModel() const {
  // inference_optimized_model_ is assigned while building the gradient graph; if this
  // getter is reached before that happened the pointer is still null, so report a clear
  // error instead of dereferencing it.
  if (inference_optimized_model_ == nullptr) {
    ORT_THROW("Inference optimized model is not available. Build the gradient graph first.");
  }

  std::string model_str;
  if (!inference_optimized_model_->ToProto().SerializeToString(&model_str)) {
    ORT_THROW("Fail to serialize inference optimized model to string.");
  }

  return model_str;
}
@ -159,6 +166,9 @@ Status ModuleGradientGraphBuilder::BuildGradientGraph() {
graph_transformation_mgr.ApplyTransformers(gradient_graph, static_cast<TransformerLevel>(i), *logger_));
}
// Save a copy of inference optimized model
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_->ToProto(), inference_optimized_model_, nullptr, *logger_));
// Build gradient graph.
GradientGraphConfiguration gradient_graph_config{};
gradient_graph_config.use_invertible_layernorm_grad = config_.use_invertible_layernorm_grad;

View file

@ -76,6 +76,12 @@ class ModuleGradientGraphBuilder {
*/
std::string GetGradientModel() const;
/**
* Get inference optimized model, i.e. the copy of the model saved after the graph
* transformers have been applied and before the gradient graph is built.
* @return The inference optimized model serialized to string.
*/
std::string GetInferenceOptimizedModel() const;
/**
* Get the training graphs information.
* @return The training graphs information.
@ -96,6 +102,7 @@ class ModuleGradientGraphBuilder {
void ReorderOutputs();
std::shared_ptr<onnxruntime::Model> model_;
std::shared_ptr<onnxruntime::Model> inference_optimized_model_;
std::shared_ptr<onnxruntime::Model> gradient_model_;
TrainingGraphInfo training_graph_info_;

View file

@ -540,6 +540,10 @@ py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class use
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetGradientModel());
})
.def("get_inference_optimized_model",
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetInferenceOptimizedModel());
})
.def("get_training_graph_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return module_gradient_graph_builder->GetTrainingGraphInfo();
});

View file

@ -366,11 +366,8 @@ class ORTModule(torch.nn.Module):
def _get_inference_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs):
    """Obtain the inference graph and initialize the module gradient graph builder.

    The result of ``self._get_inference_graph(*inputs, **kwargs)`` is stored in
    ``self._onnx_inference``. When ``self._save_onnx`` is truthy, that model is
    also written to ``<self._save_onnx_prefix>_inference.onnx``. Finally
    ``self._initialize_module_gradient_graph_builder()`` is invoked.

    Args:
        *inputs: Positional inputs forwarded unchanged to ``_get_inference_graph``.
        **kwargs: Keyword inputs forwarded unchanged to ``_get_inference_graph``.
    """
    self._onnx_inference = self._get_inference_graph(*inputs, **kwargs)

    if self._save_onnx:
        # Write the model exactly once; saving the same object to the same path
        # twice in a row only repeats the disk I/O.
        onnx.save(self._onnx_inference, self._save_onnx_prefix + '_inference.onnx')

    self._initialize_module_gradient_graph_builder()
def _create_training_session(self):
@ -397,6 +394,9 @@ class ORTModule(torch.nn.Module):
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
# 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2.
session_options.log_severity_level = int(self._verbosity)
# enable dumping optimized training graph
if self._save_onnx:
session_options.optimized_model_filepath = self._save_onnx_prefix + '_training_optimized.onnx'
self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
session_options, providers, provider_options)
@ -412,8 +412,10 @@ class ORTModule(torch.nn.Module):
self._onnx_graphs_info = self._module_gradient_graph_builder.get_training_graph_info()
if self._save_onnx:
onnx.save(self._onnx_training,
self._save_onnx_prefix + '_training.onnx')
inference_optimized_model = onnx.load_model_from_string(
self._module_gradient_graph_builder.get_inference_optimized_model())
onnx.save(inference_optimized_model, self._save_onnx_prefix + '_inference_optimized.onnx')
onnx.save(self._onnx_training, self._save_onnx_prefix + '_training.onnx')
def eval(self: T) -> T:
self._flattened_output_module.eval()