mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Improve readability of Graph::PerformTopologicalSortAndCheckIsAcyclic. (#8187)
This commit is contained in:
parent
9b19241b27
commit
17d4545ccb
1 changed files with 30 additions and 28 deletions
|
|
@ -1673,13 +1673,12 @@ void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
|
|||
GSL_SUPPRESS(es .84) // noisy warning about ignoring return value from insert(...)
|
||||
Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
|
||||
nodes_in_topological_order_.clear();
|
||||
// nodes that have been processed and added to nodes_in_topological_order.
|
||||
std::unordered_set<NodeIndex> processed_nodes;
|
||||
std::unordered_set<NodeIndex> output_nodes;
|
||||
std::unordered_set<NodeIndex> nodes_added_for_processing;
|
||||
std::unordered_set<NodeIndex> downstream_nodes; // nodes downstream of the node we're currently checking
|
||||
std::unordered_set<NodeIndex> nodes_seen; // nodes we have seen but may not have been added to nodes_added yet
|
||||
std::unordered_set<NodeIndex> nodes_added; // nodes added to topo order
|
||||
std::stack<NodeIndex> stack;
|
||||
|
||||
// push the top level nodes into nodes_in_topological_order in the order they were added
|
||||
// push the root nodes into nodes_in_topological_order in the order they were defined in the model
|
||||
// to ensure that is consistent.
|
||||
auto& nodes_in_original_order = Nodes();
|
||||
std::for_each(nodes_in_original_order.cbegin(), nodes_in_original_order.cend(),
|
||||
|
|
@ -1690,42 +1689,41 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
|
|||
// need to also consider nodes that only have Constants as inputs as top level nodes,
|
||||
// as the constant will get replaced by an initializer.
|
||||
auto input_edges = node.GetRelationships().input_edges;
|
||||
auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(), [](const Node::EdgeEnd& edge) {
|
||||
return edge.GetNode().OpType() != kConstant;
|
||||
});
|
||||
auto has_inputs = std::any_of(input_edges.cbegin(), input_edges.cend(),
|
||||
[](const Node::EdgeEnd& edge) {
|
||||
return edge.GetNode().OpType() != kConstant;
|
||||
});
|
||||
|
||||
if (!has_inputs) {
|
||||
// add to the topological list, and ensure we skip these nodes when walking the graph
|
||||
nodes_in_topological_order_.push_back(index);
|
||||
processed_nodes.insert(index);
|
||||
|
||||
// mark this as added as we've fully processed it and don't need to do it again later
|
||||
nodes_added_for_processing.insert(index);
|
||||
nodes_added.insert(index);
|
||||
nodes_seen.insert(index);
|
||||
}
|
||||
});
|
||||
|
||||
// start at the bottom and work our way up the graph
|
||||
// find all the leaf nodes (nodes with no output edges as there's no edge to a graph output)
|
||||
for (auto iter = Nodes().begin(); iter != Nodes().end(); ++iter) {
|
||||
if (iter->relationships_.output_edges.empty()) {
|
||||
// This is a leaf node.
|
||||
stack.push(iter->Index());
|
||||
}
|
||||
}
|
||||
|
||||
// work our way up from the leaf nodes
|
||||
while (!stack.empty()) {
|
||||
const NodeIndex current = stack.top();
|
||||
stack.pop();
|
||||
|
||||
if (processed_nodes.find(current) != processed_nodes.end()) {
|
||||
if (nodes_added.find(current) != nodes_added.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (nodes_added_for_processing.find(current) != nodes_added_for_processing.end()) {
|
||||
// we popped the stack and are back to a node that was added previously,
|
||||
// so we know all the upstream nodes from it have been fully processed,
|
||||
if (nodes_seen.find(current) != nodes_seen.end()) {
|
||||
// we popped the stack and are back to a node that was seen previously,
|
||||
// so we know all the upstream nodes from it have been added.
|
||||
nodes_in_topological_order_.push_back(current);
|
||||
processed_nodes.insert(current);
|
||||
output_nodes.erase(current);
|
||||
nodes_added.insert(current);
|
||||
downstream_nodes.erase(current);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -1734,28 +1732,32 @@ Status Graph::PerformTopologicalSortAndCheckIsAcyclic() {
|
|||
continue;
|
||||
}
|
||||
|
||||
stack.push(current);
|
||||
output_nodes.insert(current);
|
||||
// node hasn't been seen before, so mark it as seen and re-add it along with its inputs
|
||||
// also mark it as downstream of anything new that is added to the stack to detect acyclic graphs
|
||||
nodes_seen.insert(current);
|
||||
downstream_nodes.insert(current);
|
||||
|
||||
for (auto iter = node->InputNodesBegin(); iter != node->InputNodesEnd(); ++iter) {
|
||||
const NodeIndex idx = (*iter).Index();
|
||||
if (output_nodes.find(idx) != output_nodes.end()) {
|
||||
stack.push(current);
|
||||
|
||||
for (auto iter = node->InputNodesBegin(), end = node->InputNodesEnd(); iter != end; ++iter) {
|
||||
const NodeIndex idx = iter->Index();
|
||||
// the input to this node is also downstream of this node
|
||||
if (downstream_nodes.find(idx) != downstream_nodes.end()) {
|
||||
Status status(ONNXRUNTIME, FAIL, "This is an invalid model. Error: the graph is not acyclic.");
|
||||
return status;
|
||||
}
|
||||
|
||||
// avoid re-processing nodes
|
||||
if (nodes_added_for_processing.find(idx) == nodes_added_for_processing.end()) {
|
||||
if (nodes_seen.find(idx) == nodes_seen.end()) {
|
||||
stack.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
nodes_added_for_processing.insert(current);
|
||||
}
|
||||
|
||||
if (num_of_nodes_ >= 0 && static_cast<size_t>(num_of_nodes_) == nodes_in_topological_order_.size()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return Status(ONNXRUNTIME, FAIL, "This is an invalid model. Error: the graph is not acyclic.");
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue