diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index c567b38d1b..ef294e4887 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -755,9 +755,10 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) { std::queue q; std::unordered_set removed_nodes; - q.push(start_node.Index()); - bool is_start_node(true); + NodeIndex start_node_index = start_node.Index(); + q.push(start_node_index); + // From the current node, remove nodes bottom-up util it reaches a node with multiple outputs/graph output. while (!q.empty()) { NodeIndex cur_node_index = q.front(); @@ -781,16 +782,18 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) { continue; } const Node* parent_node = GetInputNode(cur_node, i); + if (nullptr == parent_node) { + continue; + } q.push(parent_node->Index()); } - if (is_start_node || cur_node.GetOutputEdgesCount() == 0) { + if (cur_node_index == start_node_index || cur_node.GetOutputEdgesCount() == 0) { Node* cur_node_p = graph.GetNode(cur_node_index); RemoveNodeOutputEdges(graph, *cur_node_p); graph.RemoveNode(cur_node_index); removed_nodes.insert(cur_node_index); - is_start_node = false; } }