Zhijxu/fix toposort (#18705)

In training, Shape/Size nodes need to be executed as soon as they are ready
to run, so that memory can be saved where possible.

The toposort logic was enhanced previously, but it did not take care of the
"shape->size" pattern, which caused the Size op that follows a Shape op to be
missing from the toposort result.
This commit is contained in:
zhijiang 2023-12-05 17:36:00 +08:00 committed by GitHub
parent e066fca777
commit 2b3050bb0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -57,12 +57,14 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
: ConstGraphNodes::NodeFilterFunc(nullptr))},
filter_info_{filter_info} {
std::vector<const Node*> leaf_nodes;
#ifdef ENABLE_TRAINING
// Keep the info of shape and size nodes and their parents so that after topological sort, we can move them
// right after their parents. This is to make sure the shape and size nodes are executed right after their parents
// so it's possible the input tensor memory can be released as soon as possible. This is especially important
// for non-CPU devices or for training case where some gradient graphs use only shape/size of tensors from forward.
InlinedHashSet<NodeIndex> shape_size_nodes;
InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>> shape_size_parents;
#endif
for (auto& node : graph_->Nodes()) {
// This is a leaf node (without any output node)
if (node.OutputNodesBegin() == node.OutputNodesEnd()) {
@ -72,6 +74,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
root_nodes_.push_back(node.Index());
}
#ifdef ENABLE_TRAINING
if ((node.OpType() == "Shape" || node.OpType() == "Size") && node.InputEdgesBegin() != node.InputEdgesEnd()) {
shape_size_nodes.insert(node.Index());
NodeIndex parent = node.InputNodesBegin()->Index();
@ -81,6 +84,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
shape_size_parents[parent].push_back(node.Index());
}
}
#endif
}
graph.ReverseDFSFrom(
@ -90,21 +94,24 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
nodes_in_topological_order_.push_back(n->Index());
},
NodeCompare());
#ifdef ENABLE_TRAINING
auto original = std::move(nodes_in_topological_order_);
nodes_in_topological_order_.reserve(original.size());
InlinedHashSet<NodeIndex> visited;
for (auto& node : original) {
if (shape_size_nodes.find(node) != shape_size_nodes.end()) {
if (visited.find(node) != visited.end()) {
continue;
}
nodes_in_topological_order_.push_back(node);
visited.insert(node);
if (shape_size_parents.find(node) != shape_size_parents.end()) {
for (auto& following_node : shape_size_parents[node]) {
nodes_in_topological_order_.push_back(following_node);
visited.insert(following_node);
}
}
}
#endif
#if !defined(ORT_MINIMAL_BUILD)
graph.KahnsTopologicalSort(
[this](const Node* n) {