mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Zhijxu/fix toposort (#18705)
in training, shape/size need to be executed immediately when it's ok to be executed and thus to save memory if possible; the toposort logic is enhanced before, while didn't take of the "shape->size" pattern, which make the following size op will not show up in toposort result.
This commit is contained in:
parent
e066fca777
commit
2b3050bb0c
1 changed files with 10 additions and 3 deletions
|
|
@ -57,12 +57,14 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
|
|||
: ConstGraphNodes::NodeFilterFunc(nullptr))},
|
||||
filter_info_{filter_info} {
|
||||
std::vector<const Node*> leaf_nodes;
|
||||
#ifdef ENABLE_TRAINING
|
||||
// Keep the info of shape and size nodes and their parents so that after topological sort, we can move them
|
||||
// right after their parents. This is to make sure the shape and size nodes are executed right after their parents
|
||||
// so it's possible the input tensor memory can be released as soon as possible. This is especially important
|
||||
// for non-CPU devices or for training case where some gradient graphs use only shape/size of tensors from forward.
|
||||
InlinedHashSet<NodeIndex> shape_size_nodes;
|
||||
InlinedHashMap<NodeIndex, InlinedVector<NodeIndex>> shape_size_parents;
|
||||
#endif
|
||||
for (auto& node : graph_->Nodes()) {
|
||||
// This is a leaf node (without any output node)
|
||||
if (node.OutputNodesBegin() == node.OutputNodesEnd()) {
|
||||
|
|
@ -72,6 +74,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
|
|||
if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
|
||||
root_nodes_.push_back(node.Index());
|
||||
}
|
||||
#ifdef ENABLE_TRAINING
|
||||
if ((node.OpType() == "Shape" || node.OpType() == "Size") && node.InputEdgesBegin() != node.InputEdgesEnd()) {
|
||||
shape_size_nodes.insert(node.Index());
|
||||
NodeIndex parent = node.InputNodesBegin()->Index();
|
||||
|
|
@ -81,6 +84,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
|
|||
shape_size_parents[parent].push_back(node.Index());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
graph.ReverseDFSFrom(
|
||||
|
|
@ -90,21 +94,24 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
|
|||
nodes_in_topological_order_.push_back(n->Index());
|
||||
},
|
||||
NodeCompare());
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
auto original = std::move(nodes_in_topological_order_);
|
||||
nodes_in_topological_order_.reserve(original.size());
|
||||
InlinedHashSet<NodeIndex> visited;
|
||||
for (auto& node : original) {
|
||||
if (shape_size_nodes.find(node) != shape_size_nodes.end()) {
|
||||
if (visited.find(node) != visited.end()) {
|
||||
continue;
|
||||
}
|
||||
nodes_in_topological_order_.push_back(node);
|
||||
visited.insert(node);
|
||||
if (shape_size_parents.find(node) != shape_size_parents.end()) {
|
||||
for (auto& following_node : shape_size_parents[node]) {
|
||||
nodes_in_topological_order_.push_back(following_node);
|
||||
visited.insert(following_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
graph.KahnsTopologicalSort(
|
||||
[this](const Node* n) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue