Fix reshape fusion crash (#5252)

* fix reshape fusion crash

* handling start_node statelessly

* fix
This commit is contained in:
Ye Wang 2020-09-22 15:04:13 -07:00 committed by Tianlei Wu
parent fc259de3bc
commit 87b15f32ef

View file

@ -755,9 +755,10 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec
bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
std::queue<NodeIndex> q;
std::unordered_set<NodeIndex> removed_nodes;
q.push(start_node.Index());
bool is_start_node(true);
NodeIndex start_node_index = start_node.Index();
q.push(start_node_index);
// From the current node, remove nodes bottom-up util it reaches a node with multiple outputs/graph output.
while (!q.empty()) {
NodeIndex cur_node_index = q.front();
@ -781,16 +782,18 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
continue;
}
const Node* parent_node = GetInputNode(cur_node, i);
if (nullptr == parent_node) {
continue;
}
q.push(parent_node->Index());
}
if (is_start_node || cur_node.GetOutputEdgesCount() == 0) {
if (cur_node_index == start_node_index || cur_node.GetOutputEdgesCount() == 0) {
Node* cur_node_p = graph.GetNode(cur_node_index);
RemoveNodeOutputEdges(graph, *cur_node_p);
graph.RemoveNode(cur_node_index);
removed_nodes.insert(cur_node_index);
is_start_node = false;
}
}