mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Fix reshape fusion crash (#5252)
* fix reshape fusion crash * handling start_node statelessly * fix
This commit is contained in:
parent
fc259de3bc
commit
87b15f32ef
1 changed files with 7 additions and 4 deletions
|
|
@ -755,9 +755,10 @@ bool FindPath(Graph& graph, const Node& node, bool is_input_edge, const std::vec
|
|||
bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
|
||||
std::queue<NodeIndex> q;
|
||||
std::unordered_set<NodeIndex> removed_nodes;
|
||||
q.push(start_node.Index());
|
||||
|
||||
bool is_start_node(true);
|
||||
NodeIndex start_node_index = start_node.Index();
|
||||
q.push(start_node_index);
|
||||
|
||||
// From the current node, remove nodes bottom-up util it reaches a node with multiple outputs/graph output.
|
||||
while (!q.empty()) {
|
||||
NodeIndex cur_node_index = q.front();
|
||||
|
|
@ -781,16 +782,18 @@ bool RemoveNodesWithOneOutputBottomUp(Graph& graph, const Node& start_node) {
|
|||
continue;
|
||||
}
|
||||
const Node* parent_node = GetInputNode(cur_node, i);
|
||||
if (nullptr == parent_node) {
|
||||
continue;
|
||||
}
|
||||
q.push(parent_node->Index());
|
||||
}
|
||||
|
||||
if (is_start_node || cur_node.GetOutputEdgesCount() == 0) {
|
||||
if (cur_node_index == start_node_index || cur_node.GetOutputEdgesCount() == 0) {
|
||||
Node* cur_node_p = graph.GetNode(cur_node_index);
|
||||
RemoveNodeOutputEdges(graph, *cur_node_p);
|
||||
graph.RemoveNode(cur_node_index);
|
||||
|
||||
removed_nodes.insert(cur_node_index);
|
||||
is_start_node = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue