Improve readability of Graph::PerformTopologicalSortAndCheckIsAcyclic. (#8187)

This commit is contained in:
Scott McKay 2021-06-30 12:15:17 +10:00 committed by GitHub
parent 9b19241b27
commit 17d4545ccb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1673,13 +1673,12 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
nodes_in_topological_order_.clear();
// nodes that have been processed and added to nodes_in_topological_order.
std::unordered_set<NodeIndex> processed_nodes;
std::unordered_set<NodeIndex> output_nodes;
std::unordered_set<NodeIndex> nodes_added_for_processing;
std::unordered_set<NodeIndex> downstream_nodes; // nodes downstream of the node we're currently checking
std::unordered_set<NodeIndex> nodes_seen; // nodes we have seen but may not have been added to nodes_added yet
std::unordered_set<NodeIndex> nodes_added; // nodes added to topo order
std::stack<NodeIndex> stack;
// push the top level nodes into nodes_in_topological_order in the order they were added
// push the root nodes into nodes_in_topological_order in the order they were defined in the model
// to ensure that is consistent.
auto& nodes_in_original_order = Nodes();
std::for_each(nodes_in_original_order.cbegin(), nodes_in_original_order.cend(),
@ -1690,42 +1689,41 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
// need to also consider nodes that only have Constants as inputs as top level nodes,
// as the constant will get replaced by an initializer.
auto input_edges = node.GetRelationships().input_edges;
auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(), [](const Node::EdgeEnd& edge) {
return edge.GetNode().OpType() != kConstant;
});
auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(),
[](const Node::EdgeEnd& edge) {
return edge.GetNode().OpType() != kConstant;
});
if (!has_inputs) {
// add to the topological list, and ensure we skip these nodes when walking the graph
nodes_in_topological_order_.push_back(index);
processed_nodes.insert(index);
// mark this as added as we've fully processed it and don't need to do it again later
nodes_added_for_processing.insert(index);
nodes_added.insert(index);
nodes_seen.insert(index);
}
});
// start at the bottom and work our way up the graph
// find all the leaf nodes (nodes with no output edges as there's no edge to a graph output)
for (auto iter = Nodes().begin(); iter != Nodes().end(); ++iter) {
if (iter->relationships_.output_edges.empty()) {
// This is a leaf node.
stack.push(iter->Index());
}
}
// work our way up from the leaf nodes
while (!stack.empty()) {
const NodeIndex current = stack.top();
stack.pop();
if (processed_nodes.find(current) != processed_nodes.end()) {
if (nodes_added.find(current) != nodes_added.end()) {
continue;
}
if (nodes_added_for_processing.find(current) != nodes_added_for_processing.end()) {
// we popped the stack and are back to a node that was added previously,
// so we know all the upstream nodes from it have been fully processed,
if (nodes_seen.find(current) != nodes_seen.end()) {
// we popped the stack and are back to a node that was seen previously,
// so we know all the upstream nodes from it have been added.
nodes_in_topological_order_.push_back(current);
processed_nodes.insert(current);
output_nodes.erase(current);
nodes_added.insert(current);
downstream_nodes.erase(current);
continue;
}
@ -1734,28 +1732,32 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
continue;
}
stack.push(current);
output_nodes.insert(current);
// node hasn't been seen before, so mark it as seen and re-add it along with its inputs
// also mark it as downstream of anything new that is added to the stack to detect acyclic graphs
nodes_seen.insert(current);
downstream_nodes.insert(current);
for (auto iter = node->InputNodesBegin(); iter != node->InputNodesEnd(); ++iter) {
const NodeIndex idx = (*iter).Index();
if (output_nodes.find(idx) != output_nodes.end()) {
stack.push(current);
for (auto iter = node->InputNodesBegin(), end = node->InputNodesEnd(); iter != end; ++iter) {
const NodeIndex idx = iter->Index();
// the input to this node is also downstream of this node
if (downstream_nodes.find(idx) != downstream_nodes.end()) {
Status status(ONNXRUNTIME, FAIL, "This is an invalid model. Error: the graph is not acyclic.");
return status;
}
// avoid re-processing nodes
if (nodes_added_for_processing.find(idx) == nodes_added_for_processing.end()) {
if (nodes_seen.find(idx) == nodes_seen.end()) {
stack.push(idx);
}
}
nodes_added_for_processing.insert(current);
}
if (num_of_nodes_ >= 0 && static_cast<size_t>(num_of_nodes_) == nodes_in_topological_order_.size()) {
return Status::OK();
}
return Status(ONNXRUNTIME, FAIL, "This is an invalid model. Error: the graph is not acyclic.");
}