From 9c6cc018a9a71f2d3b36647b83ef60659ebb2a4c Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Fri, 25 Mar 2022 17:13:56 -0700 Subject: [PATCH] Add utility to get the gradient graph from GradientGraphBuilder (#10995) * Add pybind method to get the gradient graph * Fix segmentation fault because of logging for gradien building --- .../orttraining/python/orttraining_pybind_state.cc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index bc7432ad91..08fe4eae99 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -782,16 +782,16 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn const std::unordered_set& x_node_arg_names, const std::string loss_node_arg_name) { std::shared_ptr model; - auto logger = logging::LoggingManager::DefaultLogger(); + auto logger_ptr = std::make_unique(logging::LoggingManager::DefaultLogger()); + logger_ptr->SetSeverity(logging::Severity::kINFO); ONNX_NAMESPACE::ModelProto model_proto; std::istringstream model_istream(serialized_model); ORT_THROW_IF_ERROR(Model::Load(model_istream, &model_proto)); - ORT_THROW_IF_ERROR(Model::Load(model_proto, model, nullptr, logger)); + ORT_THROW_IF_ERROR(Model::Load(model_proto, model, nullptr, *logger_ptr)); GradientGraphConfiguration gradient_graph_config{}; gradient_graph_config.set_gradients_as_graph_outputs = true; // Save some objects, otherwise they get lost. - auto gradient_graph_config_ptr = std::make_unique(gradient_graph_config); - auto logger_ptr = std::make_unique(logger); + auto gradient_graph_config_ptr = std::make_unique(gradient_graph_config); auto builder = std::make_unique( &model->MainGraph(), @@ -808,6 +808,11 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn }) .def("save", [](PyGradientGraphBuilder* gradient_graph_builder, const std::string& path) { ORT_THROW_IF_ERROR(Model::Save(*(gradient_graph_builder->model), path)); + }) + .def("get_model", [](PyGradientGraphBuilder* gradient_graph_builder) { + std::string model_str; + gradient_graph_builder->model->ToProto().SerializeToString(&model_str); + return py::bytes(model_str); }); py::class_ gradient_node_attribute_definition(