diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index c0ace1ffdb..78b114e8dd 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -95,7 +95,7 @@ class GraphViewer { */ const ConstGraphNodes& Nodes() const noexcept; - /** Gets the number of valid nodes in the Graph. + /** Gets the number of valid nodes in the Graph. @remarks Returns the number of nodes in filter_info_ if set. */ int NumberOfNodes() const noexcept; @@ -103,7 +103,7 @@ class GraphViewer { /** Gets the maximum NodeIndex value used by Nodes in the Graph. */ int MaxNodeIndex() const noexcept; - /** Gets the NodeIndex values for the Graph nodes, sorted into topological order. + /** Gets the NodeIndex values for the Graph nodes, sorted into topological order. @remarks Filtered using filter_info_ if set. */ const std::vector& GetNodesInTopologicalOrder(ExecutionOrder order = ExecutionOrder::DEFAULT) const; @@ -138,7 +138,7 @@ class GraphViewer { /** returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime. - @param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name' + @param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name' if the name is not found in 'graph_'. */ bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const; @@ -188,5 +188,6 @@ class GraphViewer { std::vector filtered_node_inputs_; std::vector filtered_node_inputs_including_initializers_; std::vector filtered_node_outputs_; + InitializedTensorSet filtered_initializers_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 6c89a40655..0b25b51a80 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -121,6 +121,26 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) std::copy_if(orig_order.cbegin(), orig_order.cend(), std::back_inserter(nodes_in_topological_order_), [this](NodeIndex idx) { return filtered_node_indices_.count(idx) != 0; }); + // Filter the initializers also + // Get the names of all the inputs and implicit inputs of all the nodes in this subgraph + for (const auto node_idx : filtered_node_indices_) { + const auto* node = GetNode(node_idx); + ORT_ENFORCE(node, "Mismatch between Graph and IndexedSubGraph. Node not found: ", node_idx); + const ONNX_NAMESPACE::TensorProto* tensor = nullptr; + for (const auto* node_input : node->InputDefs()) { + if (graph.GetInitializedTensor(node_input->Name(), tensor)) { + filtered_initializers_.insert({node_input->Name(), tensor}); + } + } + + // The implicit inputs for subgraphs (if any) + for (const auto* node_input : node->ImplicitInputDefs()) { + if (graph.GetInitializedTensor(node_input->Name(), tensor)) { + filtered_initializers_.insert({node_input->Name(), tensor}); + } + } + } + #if !defined(ORT_MINIMAL_BUILD) auto orig_priority_order = std::move(nodes_in_topological_order_with_priority_); nodes_in_topological_order_with_priority_.reserve(filter_info->nodes.size()); @@ -146,6 +166,10 @@ const std::string& GraphViewer::Description() const noexcept { bool GraphViewer::GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { + // if we are using filtered subgraph, the initializer has to be part of the subgraph + if (filter_info_ != nullptr && filtered_initializers_.find(tensor_name) == filtered_initializers_.cend()) + return false; + return graph_->GetInitializedTensor(tensor_name, value); } @@ -220,7 +244,9 @@ const std::vector& GraphViewer::GetRootNodes() const { } const InitializedTensorSet& GraphViewer::GetAllInitializedTensors() const noexcept { - return graph_->GetAllInitializedTensors(); + return (filter_info_ == nullptr) + ? graph_->GetAllInitializedTensors() + : filtered_initializers_; } const NodeArg* GraphViewer::GetNodeArg(const std::string& name) const { diff --git a/onnxruntime/test/ir/graph_viewer_test.cc b/onnxruntime/test/ir/graph_viewer_test.cc index 9ffe6f82e2..eba3b70bbe 100644 --- a/onnxruntime/test/ir/graph_viewer_test.cc +++ b/onnxruntime/test/ir/graph_viewer_test.cc @@ -92,6 +92,14 @@ TEST(GraphViewer, FilteredGraph) { EXPECT_EQ(viewer.GetOutputs().size(), final_metadef->outputs.size()); EXPECT_EQ(viewer.IsSubgraph(), false) << "GraphViewer is for a filtered set of nodes of a single graph and not a nested subgraph"; + + // Verify the viewer's initializers are filtered as well + const auto& viewer_initializers = viewer.GetAllInitializedTensors(); + EXPECT_EQ(viewer_initializers.size(), initializers.size()); + // We should have less initializers in the viewer than the underlying graph + EXPECT_LT(viewer_initializers.size(), graph.GetAllInitializedTensors().size()); + // Pick a initializers which is not in the viewer, and check it is not part of the viewer's initializers + EXPECT_TRUE(viewer_initializers.count("Constant15770PastValue16469") == 0); } } // namespace test } // namespace onnxruntime