diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 7dd9aa9f55..9557b0e3ec 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1673,13 +1673,12 @@ void Graph::KahnsTopologicalSort(const std::function& enter, GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...) Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { nodes_in_topological_order_.clear(); - // nodes that have been processed and added to nodes_in_topological_order. - std::unordered_set processed_nodes; - std::unordered_set output_nodes; - std::unordered_set nodes_added_for_processing; + std::unordered_set downstream_nodes; // nodes downstream of the node we're currently checking + std::unordered_set nodes_seen; // nodes we have seen but may not have been added to nodes_added yet + std::unordered_set nodes_added; // nodes added to topo order std::stack stack; - // push the top level nodes into nodes_in_topological_order in the order they were added + // push the root nodes into nodes_in_topological_order in the order they were defined in the model // to ensure that is consistent. auto& nodes_in_original_order = Nodes(); std::for_each(nodes_in_original_order.cbegin(), nodes_in_original_order.cend(), @@ -1690,42 +1689,41 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { // need to also consider nodes that only have Constants as inputs as top level nodes, // as the constant will get replaced by an initializer. auto input_edges = node.GetRelationships().input_edges; - auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(), [](const Node::EdgeEnd& edge) { - return edge.GetNode().OpType() != kConstant; - }); + auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(), + [](const Node::EdgeEnd& edge) { + return edge.GetNode().OpType() != kConstant; + }); if (!has_inputs) { // add to the topological list, and ensure we skip these nodes when walking the graph nodes_in_topological_order_.push_back(index); - processed_nodes.insert(index); - - // mark this as added as we've fully processed it and don't need to do it again later - nodes_added_for_processing.insert(index); + nodes_added.insert(index); + nodes_seen.insert(index); } }); - // start at the bottom and work our way up the graph + // find all the leaf nodes (nodes with no output edges as there's no edge to a graph output) for (auto iter = Nodes().begin(); iter != Nodes().end(); ++iter) { if (iter->relationships_.output_edges.empty()) { - // This is a leaf node. stack.push(iter->Index()); } } + // work our way up from the leaf nodes while (!stack.empty()) { const NodeIndex current = stack.top(); stack.pop(); - if (processed_nodes.find(current) != processed_nodes.end()) { + if (nodes_added.find(current) != nodes_added.end()) { continue; } - if (nodes_added_for_processing.find(current) != nodes_added_for_processing.end()) { - // we popped the stack and are back to a node that was added previously, - // so we know all the upstream nodes from it have been fully processed, + if (nodes_seen.find(current) != nodes_seen.end()) { + // we popped the stack and are back to a node that was seen previously, + // so we know all the upstream nodes from it have been added. nodes_in_topological_order_.push_back(current); - processed_nodes.insert(current); - output_nodes.erase(current); + nodes_added.insert(current); + downstream_nodes.erase(current); continue; } @@ -1734,28 +1732,32 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() { continue; } - stack.push(current); - output_nodes.insert(current); + // node hasn't been seen before, so mark it as seen and re-add it along with its inputs + // also mark it as downstream of anything new that is added to the stack to detect acyclic graphs + nodes_seen.insert(current); + downstream_nodes.insert(current); - for (auto iter = node->InputNodesBegin(); iter != node->InputNodesEnd(); ++iter) { - const NodeIndex idx = (*iter).Index(); - if (output_nodes.find(idx) != output_nodes.end()) { + stack.push(current); + + for (auto iter = node->InputNodesBegin(), end = node->InputNodesEnd(); iter != end; ++iter) { + const NodeIndex idx = iter->Index(); + // the input to this node is also downstream of this node + if (downstream_nodes.find(idx) != downstream_nodes.end()) { Status status(ONNXRUNTIME, FAIL, "This is an invalid model. Error: the graph is not acyclic."); return status; } // avoid re-processing nodes - if (nodes_added_for_processing.find(idx) == nodes_added_for_processing.end()) { + if (nodes_seen.find(idx) == nodes_seen.end()) { stack.push(idx); } } - - nodes_added_for_processing.insert(current); } if (num_of_nodes_ >= 0 && static_cast(num_of_nodes_) == nodes_in_topological_order_.size()) { return Status::OK(); } + return Status(ONNXRUNTIME, FAIL, "This is an invalid model. Error: the graph is not acyclic."); }