From eb41bfb7b585e8eed3e876f0fd2b085ee2879fc9 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Sun, 19 Jun 2022 19:28:18 -0700 Subject: [PATCH] Fix graph viewer to proto (#11862) * Add test for case where main const initialier in subgraph * update test to use trt ep * add initializer when converting from graph viewer to proto * add comments * add comments * add comments * only add initialier that is outer scope value * make including outer scope value optional * modify python format * modify python format * modify python format * Remove test * remove redundant argument --- .../core/graph/graph_proto_serializer.cc | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 89eb20f734..aad52b6174 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -43,12 +43,33 @@ void GraphViewerToProto(const GraphViewer& graph_view, } if (include_initializer) { + std::unordered_set current_scope_initializer_set; + auto& initializers = graph_view.GetAllInitializedTensors(); for (auto& it : initializers) { auto* p_initializer = graph_proto.add_initializer(); *p_initializer = *(it.second); + current_scope_initializer_set.insert(it.first); + } + + + // handle outer scope value which is a constant initializer + if (include_outer_scope_args) { + for (auto& node_idx : graph_view.GetNodesInTopologicalOrder()) { + const auto& node = graph_view.GetNode(node_idx); + for (const auto& input : node->InputDefs()) { + if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) { + continue; + } + if (graph_view.IsConstantInitializer(input->Name(), true)) { + auto* p_initializer = graph_proto.add_initializer(); + *p_initializer = *(graph_view.GetConstantInitializer(input->Name(), true)); + current_scope_initializer_set.insert(input->Name()); + } + } + } } } } -} \ No newline at end of file +}